diff --git a/CMakeLists.txt b/CMakeLists.txt index efa68c9ba243af3c7cdca52b915cc14d307ae89f..1594e798a2ba3f735a28a43ef933d80b3b3f8964 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,7 +54,7 @@ option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON) option(WITH_DOUBLE "Compile PaddlePaddle with double precision" OFF) option(WITH_RDMA "Compile PaddlePaddle with RDMA support" OFF) option(WITH_TIMER "Compile PaddlePaddle with stats timer" OFF) -option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler" OFF) +option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF) option(WITH_DOC "Compile PaddlePaddle with documentation" OFF) option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) @@ -254,6 +254,12 @@ elseif() set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in MKL only now." FORCE) endif() +if (WITH_PROFILER) + find_package(Gperftools REQUIRED) + include_directories(${GPERFTOOLS_INCLUDE_DIR}) + add_definitions(-DWITH_GPERFTOOLS) +endif() + include(generic) # simplify cmake module include(package) # set paddle packages include(ccache) # set ccache for compilation diff --git a/README.md b/README.md index 56d6c10c642787836abb55cb2974bda0b8d22da4..c535e9514e1cac9aff51edfcd9bcdc5d34ccd9fd 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ [![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle) -[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html) -[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) +[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) +[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) [![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases) [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE) @@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle. Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle. -### Latest PaddlePaddle Release: [Fluid 1.1.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.1) +### Latest PaddlePaddle Release: [Fluid 1.2.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.2) ### Install Latest Stable Release: ``` # Linux CPU @@ -27,9 +27,9 @@ pip install paddlepaddle # Linux GPU cuda9cudnn7 pip install paddlepaddle-gpu # Linux GPU cuda8cudnn7 -pip install paddlepaddle-gpu==1.1.0.post87 +pip install paddlepaddle-gpu==1.2.0.post87 # Linux GPU cuda8cudnn5 -pip install paddlepaddle-gpu==1.1.0.post85 +pip install paddlepaddle-gpu==1.2.0.post85 # For installation on other platform, refer to http://paddlepaddle.org/ ``` @@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==1.1.0.post85 ## Installation -It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) on our website. +It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/install/index_cn.html) on our website. ## Documentation -We provide [English](http://paddlepaddle.org/documentation/docs/en/1.1/getstarted/index_en.html) and -[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.1/beginners_guide/index.html) documentation. +We provide [English](http://paddlepaddle.org/documentation/docs/en/1.2/getstarted/index_en.html) and +[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/index.html) documentation. - [Deep Learning 101](https://github.com/PaddlePaddle/book) You might want to start from this online interactive book that can run in a Jupyter Notebook. -- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.1/user_guides/howto/training/cluster_howto.html) +- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.2/user_guides/howto/training/cluster_howto.html) You can run distributed training jobs on MPI clusters. -- [Python API](http://paddlepaddle.org/documentation/api/zh/1.1/fluid.html) +- [Python API](http://paddlepaddle.org/documentation/docs/zh/1.2/api_cn/index_cn.html) Our new API enables much shorter programs. -- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.1/advanced_usage/development/contribute_to_paddle.html) +- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.2/advanced_usage/development/contribute_to_paddle/index_cn.html) We appreciate your contributions! diff --git a/cmake/FindGperftools.cmake b/cmake/FindGperftools.cmake new file mode 100644 index 0000000000000000000000000000000000000000..928f573a4fb82391859e334d50e6c8ed0e26aae2 --- /dev/null +++ b/cmake/FindGperftools.cmake @@ -0,0 +1,63 @@ +# Tries to find Gperftools. +# +# Usage of this module as follows: +# +# find_package(Gperftools) +# +# Variables used by this module, they can change the default behaviour and need +# to be set before calling find_package: +# +# Gperftools_ROOT_DIR Set this variable to the root installation of +# Gperftools if the module has problems finding +# the proper installation path. +# +# Variables defined by this module: +# +# GPERFTOOLS_FOUND System has Gperftools libs/headers +# GPERFTOOLS_LIBRARIES The Gperftools libraries (tcmalloc & profiler) +# GPERFTOOLS_INCLUDE_DIR The location of Gperftools headers + +find_library(GPERFTOOLS_TCMALLOC + NAMES tcmalloc + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_library(GPERFTOOLS_PROFILER + NAMES profiler + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_library(GPERFTOOLS_TCMALLOC_AND_PROFILER + NAMES tcmalloc_and_profiler + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_path(GPERFTOOLS_INCLUDE_DIR + NAMES gperftools/heap-profiler.h + HINTS ${Gperftools_ROOT_DIR}/include) + +set(GPERFTOOLS_LIBRARIES ${GPERFTOOLS_TCMALLOC_AND_PROFILER}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + Gperftools + DEFAULT_MSG + GPERFTOOLS_LIBRARIES + GPERFTOOLS_INCLUDE_DIR) + +mark_as_advanced( + Gperftools_ROOT_DIR + GPERFTOOLS_TCMALLOC + GPERFTOOLS_PROFILER + GPERFTOOLS_TCMALLOC_AND_PROFILER + GPERFTOOLS_LIBRARIES + GPERFTOOLS_INCLUDE_DIR) + +# create IMPORTED targets +if (Gperftools_FOUND AND NOT TARGET gperftools::tcmalloc) + add_library(gperftools::tcmalloc UNKNOWN IMPORTED) + set_target_properties(gperftools::tcmalloc PROPERTIES + IMPORTED_LOCATION ${GPERFTOOLS_TCMALLOC} + INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}") + add_library(gperftools::profiler UNKNOWN IMPORTED) + set_target_properties(gperftools::profiler PROPERTIES + IMPORTED_LOCATION ${GPERFTOOLS_PROFILER} + INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}") +endif() diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 4e17ddee73958106d5e2c8c8ea5661acc758518a..51f7a61631d7102b60646abe1c6dd7775692f157 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -86,6 +86,7 @@ endif(NOT WITH_GOLANG) if(WITH_GPU) add_definitions(-DPADDLE_WITH_CUDA) + add_definitions(-DEIGEN_USE_GPU) FIND_PACKAGE(CUDA REQUIRED) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 312fbaa0b3d83c37debe78be82503103eabc0bfa..a8b9dcfcf5eec39af0f59c03b1ed9bd4b71ee7bf 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -110,6 +110,14 @@ function(find_fluid_modules TARGET_NAME) endif() endfunction(find_fluid_modules) + +function(common_link TARGET_NAME) + if (WITH_PROFILER) + target_link_libraries(${TARGET_NAME} gperftools::profiler) + endif() +endfunction() + + # find all third_party modules is used for paddle static library # for reduce the dependency when building the inference libs. set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY) @@ -274,6 +282,7 @@ function(cc_library TARGET_NAME) endif() target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) + common_link(${TARGET_NAME}) endif() # cpplint code style @@ -340,6 +349,7 @@ function(cc_binary TARGET_NAME) if(cc_binary_DEPS) target_link_libraries(${TARGET_NAME} ${cc_binary_DEPS}) add_dependencies(${TARGET_NAME} ${cc_binary_DEPS}) + common_link(${TARGET_NAME}) endif() endfunction(cc_binary) @@ -362,6 +372,7 @@ function(cc_test TARGET_NAME) target_link_libraries(${TARGET_NAME} ${win32_deps}) endif(WIN32) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + common_link(${TARGET_NAME}) add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} ${cc_test_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) @@ -420,6 +431,7 @@ function(nv_binary TARGET_NAME) if(nv_binary_DEPS) target_link_libraries(${TARGET_NAME} ${nv_binary_DEPS}) add_dependencies(${TARGET_NAME} ${nv_binary_DEPS}) + common_link(${TARGET_NAME}) endif() endif() endfunction(nv_binary) @@ -433,6 +445,7 @@ function(nv_test TARGET_NAME) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) if (nv_test_SERIAL) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) @@ -499,6 +512,7 @@ function(hip_binary TARGET_NAME) if(hip_binary_DEPS) target_link_libraries(${TARGET_NAME} ${hip_binary_DEPS}) add_dependencies(${TARGET_NAME} ${hip_binary_DEPS}) + common_link(${TARGET_NAME}) endif() endif() endfunction(hip_binary) @@ -518,6 +532,7 @@ function(hip_test TARGET_NAME) set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP) target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags) add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags) + common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) endif() endfunction(hip_test) @@ -560,6 +575,7 @@ function(go_library TARGET_NAME) endif() if(go_library_DEPS) add_dependencies(${TARGET_NAME} ${go_library_DEPS}) + common_link(${TARGET_NAME}) endif(go_library_DEPS) # The "source file" of the library is `${dummyfile}` which never diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6c6026911b37c920c97766f591c579d01b343d74..da3b2f6347487c549c4dee68e5f87b27366262e6 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -67,6 +67,7 @@ paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100)) +paddle.fluid.layers.bpr_loss ArgSpec(args=['input', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None)) @@ -198,6 +199,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) +paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt index 6b526f0103ad3c530c06a68757cf89293f4fb84b..595454e90b9cd713fd2baed24538cf5fbc93934a 100644 --- a/paddle/fluid/CMakeLists.txt +++ b/paddle/fluid/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(memory) add_subdirectory(platform) add_subdirectory(framework) +add_subdirectory(imperative) add_subdirectory(operators) add_subdirectory(string) add_subdirectory(recordio) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 795f7a2ee3f397224542b61cf3a42abc00ea5b92..9c8bd1518a699a89484bc2381f3fa34336430f25 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -72,6 +72,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) +cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory) + cc_library(reader SRCS reader.cc DEPS lod_tensor ddim) cc_test(reader_test SRCS reader_test.cc DEPS reader) @@ -129,11 +131,13 @@ cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) -if(NOT WIN32) -cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) -cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog - shape_inference data_transform lod_tensor profiler) -endif(NOT WIN32) +if(WITH_NGRAPH) + if(NOT WIN32) + cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) + cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog + shape_inference data_transform lod_tensor profiler ngraph) + endif(NOT WIN32) +endif(WITH_NGRAPH) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) @@ -169,14 +173,20 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() - if(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper) - else(NOT WIN32) + if(WITH_NGRAPH) + if(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph ngraph_operator variable_helper) + else(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) + endif(NOT WIN32) + else(WITH_NGRAPH) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) - endif(NOT WIN32) + endif(WITH_NGRAPH) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() +target_link_libraries(executor garbage_collector) + cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor graph build_strategy diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index c9e3a8ac1d1e5228725bff49ecc6d91e640dfe57..5467f6d1b23c0058f06387e3da97c4193dd5ca6c 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -151,19 +151,22 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, auto out_format = platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout)); - void* in_data = GetDataFromTensor(in, in_type); - // output tensor has the same dims as input. Reorder don't change dims out->Resize(in.dims()); - auto out_data = out->mutable_data(expected_kernel_type.place_, in.type()); - - auto in_memory = memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data); - auto out_memory = - memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data); + if (in_format != out_format) { + void* in_data = GetDataFromTensor(in, in_type); + auto out_data = out->mutable_data(expected_kernel_type.place_, in.type()); - platform::Reorder(in_memory, out_memory); + auto in_memory = + memory({{{in_tz}, in_type, in_format}, cpu_engine}, in_data); + auto out_memory = + memory({{{out_tz}, out_type, out_format}, cpu_engine}, out_data); + platform::Reorder(in_memory, out_memory); + } else { + out->ShareDataWith(in); + } out->set_layout(out_layout); // reset format since the out tensor will be feed to non-MKLDNN OPkernel out->set_format(memory::format::format_undef); diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 387a99f9b27bf5ab13c28c5ccd828784c418e2c2..8dfb982b77c84a75d9c85ed588eb2d07aeb01f33 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -45,10 +45,10 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) -if (WITH_GPU) - cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle scale_loss_grad_op_handle rpc_op_handle - all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) -endif() +cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle) +cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) +cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass) +cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper) cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass) cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass) @@ -56,10 +56,7 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) -set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass) -if (WITH_GPU) - list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) -endif() +set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 7ad1e40c600c6e70cea822fac777ff20163078e6..7beb8c8de9fc49aebc66ca44de8736240aabbc30 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -20,11 +20,13 @@ namespace paddle { namespace framework { namespace details { ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, - platform::Place place) + platform::Place place, + size_t scope_idx) : OpHandleBase(node), op_(framework::OpRegistry::CreateOp(*node->Op())), scope_(scope), - place_(place) {} + place_(place), + scope_idx_(scope_idx) {} void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index 5346b56dd63d19e3cffedf2e28bba3942abb3856..5b8b70c5641672f3904f657c9a087dc3156ee525 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -29,7 +29,8 @@ namespace framework { namespace details { struct ComputationOpHandle : public OpHandleBase { public: - ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place); + ComputationOpHandle(ir::Node *node, Scope *scope, platform::Place place, + size_t scope_idx); std::string Name() const override; @@ -39,6 +40,8 @@ struct ComputationOpHandle : public OpHandleBase { void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; } + size_t GetScopeIdx() const { return scope_idx_; } + protected: void RunImpl() override; @@ -48,6 +51,7 @@ struct ComputationOpHandle : public OpHandleBase { std::unique_ptr op_; Scope *scope_; platform::Place place_; + size_t scope_idx_; bool is_lock_and_record_event_free_{false}; }; } // namespace details diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..abacb11e3b018308c20a67630e3ff34cca7d3387 --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_device_guard.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +EagerDeletionOpHandle::EagerDeletionOpHandle( + ir::Node *node, const Scope *scope, const platform::Place &place, + const std::unordered_set &var_names, GarbageCollector *gc, + AtomicReferenceCountMap *ref_cnts) + : OpHandleBase(node), + scope_(scope), + var_names_(var_names), + gc_(gc), + ref_cnts_(ref_cnts) { +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place)) { + dev_ctx_ = reinterpret_cast( + platform::DeviceContextPool::Instance().Get(place)); + if (dynamic_cast(gc_)) { + platform::CUDADeviceGuard guard( + boost::get(place).device); + PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); + PADDLE_ENFORCE_NOT_NULL(event_); + } + } +#endif +} + +EagerDeletionOpHandle::~EagerDeletionOpHandle() { +#ifdef PADDLE_WITH_CUDA + if (event_) { + auto gpu_place = boost::get(dev_ctx_->GetPlace()); + platform::CUDADeviceGuard guard(gpu_place.device); + PADDLE_ENFORCE(cudaEventDestroy(event_)); + } +#endif +} + +std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } + +void EagerDeletionOpHandle::RunImpl() { + auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); + std::deque> garbages; + for (auto &name : var_names_) { + auto it = ref_cnts_->find(name); + // Var not found, not reference count has not decreased to 0 + if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) { + continue; + } + + auto *var = exec_scope->FindVar(name); + if (var == nullptr) { + continue; + } + + VLOG(2) << "Erase variable " << name; + + if (var->IsType()) { + garbages.emplace_back(var->GetMutable()->MoveMemoryHolder()); + } else if (var->IsType()) { + garbages.emplace_back( + var->GetMutable()->mutable_value()->MoveMemoryHolder()); + } else if (var->IsType()) { + auto *tensor_arr = var->GetMutable(); + for (auto &t : *tensor_arr) { + garbages.emplace_back(t.MoveMemoryHolder()); + } + } else { + PADDLE_THROW("Type %s of %s is not supported eager deletion", + var->Type().name(), name); + } + } + + if (!garbages.empty()) { + ClearGarbages(&garbages); + } +} + +void EagerDeletionOpHandle::ClearGarbages( + std::deque> *garbages) { +#ifdef PADDLE_WITH_CUDA + if (event_) { + auto compute_stream = dev_ctx_->stream(); + auto callback_stream = + reinterpret_cast(gc_)->stream(); + auto callback_func = [=]() { + PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream)); + PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0)); + }; + gc_->Add(std::move(*garbages), callback_func); + } else { +#endif + gc_->Add(std::move(*garbages)); +#ifdef PADDLE_WITH_CUDA + } +#endif +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.h b/paddle/fluid/framework/details/eager_deletion_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..64867afad5b70a2ba31e5cb315daffcf433b5935 --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.h @@ -0,0 +1,58 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace details { + +class EagerDeletionOpHandle : public OpHandleBase { + public: + EagerDeletionOpHandle(ir::Node *node, const Scope *scope, + const platform::Place &place, + const std::unordered_set &var_names, + GarbageCollector *gc, + AtomicReferenceCountMap *ref_cnts); + + ~EagerDeletionOpHandle(); + + std::string Name() const override; + + protected: + void RunImpl() override; + + private: + void ClearGarbages(std::deque> *garbages); + + const Scope *scope_; + std::unordered_set var_names_; + GarbageCollector *gc_; // not own + AtomicReferenceCountMap *ref_cnts_; // not own +#ifdef PADDLE_WITH_CUDA + platform::CUDADeviceContext *dev_ctx_{nullptr}; + cudaEvent_t event_{nullptr}; +#endif +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/eager_deletion_pass.cc b/paddle/fluid/framework/details/eager_deletion_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e42d0b4972d567dd769cad6ff8b9d45380ab77a --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_pass.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_pass.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +std::unique_ptr EagerDeletionPass::ApplyImpl( + std::unique_ptr graph) const { + auto &ref_cnts = + Get>(kRuntimeReferenceCount); + PADDLE_ENFORCE(ref_cnts.empty(), + "kRuntimeReferenceCount should be initialized here!"); + + const auto &vars = graph->Get(kGraphVars); + ref_cnts.resize(vars.size()); + + const auto &last_live_ops = + Get>(kLastLiveOpsOfVars); + const auto &gcs = Get(kGarbageCollector); + const auto &places = Get>(kAllPlaces); + + // a reverse map of last_live_ops + // i.e., last op --> variable names which can be deleted. + std::unordered_map> + op_vars_map; + + for (auto &var_ops_map : last_live_ops) { + for (auto &var_ops_pair : var_ops_map) { + const std::string &var_name = var_ops_pair.first; + for (auto *op : var_ops_pair.second) { + op_vars_map[op].insert(var_name); + } + } + } + + for (auto &pair : op_vars_map) { + auto *op = pair.first; + auto &var_names = pair.second; + + auto *eager_deletion_node = + graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation); + auto *eager_deletion_op = new EagerDeletionOpHandle( + eager_deletion_node, op->GetScope(), op->GetPlace(), var_names, + gcs.at(places[op->GetScopeIdx()]).get(), + &(ref_cnts[op->GetScopeIdx()])); + + auto it = std::find_if( + op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) { + return dynamic_cast(var) != nullptr; + }); + + if (it != op->Outputs().end()) { + eager_deletion_op->AddInput(*it); + } else { + auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dep_var); + op->AddOutput(dep_var); + eager_deletion_op->AddInput(dep_var); + } + + auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); + graph->Get(kGraphDepVars).emplace(dummy_leaf); + eager_deletion_op->AddOutput(dummy_leaf); + } + + VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)"; + return graph; +} + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(eager_deletion_pass, + paddle::framework::details::EagerDeletionPass) + .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount) + .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars) + .RequirePassAttr(paddle::framework::details::kAllPlaces) + .RequirePassAttr(paddle::framework::details::kGarbageCollector); diff --git a/paddle/fluid/framework/details/eager_deletion_pass.h b/paddle/fluid/framework/details/eager_deletion_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..d7a7a9709d970841060778806451bc21cb2c7571 --- /dev/null +++ b/paddle/fluid/framework/details/eager_deletion_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace details { + +class EagerDeletionPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 6c4e0e9168a932dcf80e87ce489a751d0db682b2..6e8cf86fcc99d411e67c54a50bb11466158b8612 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -574,7 +574,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, int dev_id) const { result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), - local_scopes_[dev_id], places_[dev_id])); + local_scopes_[dev_id], places_[dev_id], dev_id)); CreateOpHandleIOs(result, node, dev_id); } @@ -697,8 +697,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; - result->Get(kGraphOps).emplace_back( - new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); + result->Get(kGraphOps).emplace_back(new ComputationOpHandle( + result->CreateOpNode(node->Op()), s, p, scope_idx)); CreateOpHandleIOs(result, node, scope_idx); } } diff --git a/paddle/fluid/framework/details/op_graph_view.cc b/paddle/fluid/framework/details/op_graph_view.cc index 4838c4198ff35ba3fb562f3a7c0563ee60179e3b..d3865c2c2919c2d43521e4f51013e5fa1b10416d 100644 --- a/paddle/fluid/framework/details/op_graph_view.cc +++ b/paddle/fluid/framework/details/op_graph_view.cc @@ -23,6 +23,8 @@ namespace details { OpGraphView::OpGraphView(const std::vector &ops) { Build(ops); } void OpGraphView::Build(const std::vector &ops) { + preceding_ops_.clear(); + pending_ops_.clear(); for (auto &op : ops) { preceding_ops_[op]; pending_ops_[op]; @@ -40,6 +42,7 @@ void OpGraphView::Build(const std::vector &ops) { std::unordered_set OpGraphView::AllOps() const { std::unordered_set ret; + ret.reserve(preceding_ops_.size()); for (auto &pair : preceding_ops_) { ret.insert(pair.first); } diff --git a/paddle/fluid/framework/details/op_graph_view.h b/paddle/fluid/framework/details/op_graph_view.h index afb3e8e59461eeba10d7027fc70b89cc170c1805..77aa02eba56acb3bb20a5c5a55c75af78a3c1c81 100644 --- a/paddle/fluid/framework/details/op_graph_view.h +++ b/paddle/fluid/framework/details/op_graph_view.h @@ -14,7 +14,7 @@ #pragma once -#include +#include #include #include #include @@ -34,6 +34,11 @@ class OpGraphView { bool HasOp(OpHandleBase *op) const; + // Use a visitor to visit all pending ops of op + // Stop when callback returns false + template + bool VisitAllPendingOps(OpHandleBase *op, Callback &&callback) const; + private: void Build(const std::vector &ops); void EnforceHasOp(OpHandleBase *op) const; @@ -44,6 +49,28 @@ class OpGraphView { pending_ops_; }; +template +bool OpGraphView::VisitAllPendingOps(OpHandleBase *op, + Callback &&callback) const { + EnforceHasOp(op); + std::unordered_set visited; + std::queue q; + q.push(op); + do { + op = q.front(); + q.pop(); + for (auto &pending_op : pending_ops_.at(op)) { + if (visited.count(pending_op) == 0) { + visited.insert(pending_op); + if (!callback(pending_op)) { + return false; + } + } + } + } while (!q.empty()); + return true; +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h index bd777e41f8588ecce83f84838a8a95431eb81240..c00c5bc2d1b4b78593f99c819b5a3d642150e773 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h @@ -18,7 +18,6 @@ #include #include "ThreadPool.h" -#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" namespace paddle { diff --git a/paddle/fluid/framework/details/reference_count_op_handle.h b/paddle/fluid/framework/details/reference_count_op_handle.h deleted file mode 100644 index cc4ccfbdfc720284e683a8f3f59a4aa57a3a9eb1..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/reference_count_op_handle.h +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/framework/details/op_handle_base.h" -#include "paddle/fluid/framework/garbage_collector.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/framework/tensor.h" - -namespace paddle { -namespace framework { -namespace details { - -using ReferenceCountMap = std::unordered_map; -using AtomicReferenceCountMap = - std::unordered_map>; -using DeviceReferenceCountMap = - std::unordered_map>; -using AtomicDeviceReferenceCountMap = - std::unordered_map>; -using DeviceGarbageCollectorMap = - std::unordered_map>>; - -class ReferenceCountOpHandle : public OpHandleBase { - public: - ReferenceCountOpHandle(ir::Node *node, const Scope *scope, - const platform::CUDAPlace &place, - const std::vector &var_names, - GarbageCollector *gc, - AtomicReferenceCountMap *ref_cnts) - : OpHandleBase(node), scope_(scope), gc_(gc), ref_cnts_(ref_cnts) { - dev_ctx_ = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - if (IsStreamGarabageCollector()) { - platform::SetDeviceId(place.device); - PADDLE_ENFORCE(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); - } - - for (auto &name : var_names) AddVar(name); - } - - ~ReferenceCountOpHandle() { - if (IsStreamGarabageCollector()) { - auto gpu_place = boost::get(dev_ctx_->GetPlace()); - platform::SetDeviceId(gpu_place.device); - PADDLE_ENFORCE(cudaEventDestroy(event_)); - } - } - - std::string Name() const override { return "reference_count"; } - - void AddVar(const std::string &name) { - auto it = var_names_.find(name); - if (it != var_names_.end()) - ++(it->second); - else - var_names_[name] = 1; - } - - protected: - void RunImpl() override { - auto *exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); - std::vector tensors; - for (auto &pair : var_names_) { - auto &name = pair.first; - auto it = ref_cnts_->find(name); - if (it == ref_cnts_->end()) continue; - - auto *var = exec_scope->FindVar(name); - if (var == nullptr) continue; - - if (var->IsType()) { - if (it->second.fetch_sub(pair.second) <= pair.second) { - tensors.emplace_back(var->GetMutable()); - } - } else if (var->IsType()) { - if (it->second.fetch_sub(pair.second) <= pair.second) { - tensors.emplace_back( - var->GetMutable()->mutable_value()); - } - } - } - - if (!tensors.empty()) { - ClearTensors(tensors); - } - } - - private: - void ClearTensors(const std::vector &tensors) { - auto *gc = dynamic_cast *>(gc_); - if (gc != nullptr) { - auto compute_stream = dev_ctx_->stream(); - auto callback_stream = gc->stream(); - auto callback_func = [=]() { - PADDLE_ENFORCE(cudaEventRecord(event_, compute_stream)); - PADDLE_ENFORCE(cudaStreamWaitEvent(callback_stream, event_, 0)); - }; - gc_->Add(tensors, callback_func); - } else { - gc_->Add(tensors); - } - } - - bool IsStreamGarabageCollector() const { - return dynamic_cast *>(gc_) != nullptr; - } - - const Scope *scope_; - platform::CUDADeviceContext *dev_ctx_; - std::unordered_map var_names_; - GarbageCollector *gc_; // not own - AtomicReferenceCountMap *ref_cnts_; // not own - cudaEvent_t event_; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 08783fb5f8b18329c9167edb0dac39b7dd42a746..13a042d8e6ed7f18c76387b666d681df0eabd0b5 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -14,187 +14,240 @@ #include #include +#include #include #include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/op_graph_view.h" #include "paddle/fluid/framework/details/reference_count_pass.h" +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { namespace details { -static ComputationOpHandle *FindNextComputationOpHandle(VarHandle *var_in) { - std::queue queue; - queue.push(var_in); - do { - auto *var = queue.front(); - queue.pop(); - for (auto *op : var->PendingOps()) { - auto *compute_op = dynamic_cast(op); - if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) { - return compute_op; +// A functor to shrink/remove operators who depend on other operators in a set +class ShrinkDepsOpFunctor { + private: + enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 }; + + public: + explicit ShrinkDepsOpFunctor(const std::vector &all_ops) + : graph_(all_ops) {} + + template + OpSet operator()(const OpSet &op_set) const { + using KeyType = typename OpSet::key_type; + static_assert( + std::is_base_of::type>::value, + "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase"); + + if (op_set.size() <= 1) return op_set; + std::vector ops(op_set.begin(), op_set.end()); + OpSet ret; + auto rels = GetRelations(ops); + auto not_before = [](RelationShip r) { return r != kBefore; }; + for (size_t i = 0; i < rels.size(); ++i) { + if (std::all_of(rels[i].begin(), rels[i].end(), not_before)) { + ret.emplace(static_cast(ops[i])); } - for (auto *out_var : op->Outputs()) { - queue.push(out_var); + } + return ret; + } + + private: + std::vector> GetRelations( + const std::vector &ops) const { + std::unordered_map op_to_idx; + for (size_t i = 0; i < ops.size(); ++i) { + PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph"); + op_to_idx[ops[i]] = i; + } + + PADDLE_ENFORCE(op_to_idx.size() == ops.size(), "Duplicate ops"); + + std::vector> ret(ops.size()); + for (auto &e : ret) { + e.assign(ops.size(), kSame); + } + + size_t found_num = ops.size(); + size_t total_num = ops.size() * ops.size(); + auto visitor = [&](OpHandleBase *op, size_t i) { + auto it = op_to_idx.find(op); + if (it != op_to_idx.end()) { + size_t j = it->second; + if (i != j && ret[i][j] == kSame) { + ret[i][j] = kBefore; + ret[j][i] = kAfter; + found_num += 2; + if (found_num == total_num) { + return false; + } + } + } + return true; + }; + + for (size_t i = 0; i < ops.size(); ++i) { + auto sub_visitor = [&, i](OpHandleBase *op) { return visitor(op, i); }; + if (!graph_.VisitAllPendingOps(ops[i], sub_visitor)) { + break; + } + } + + for (size_t i = 0; i < ops.size(); ++i) { + for (size_t j = i + 1; j < ops.size(); ++j) { + if (ret[i][j] != kSame) continue; + ret[i][j] = kNoDeps; + ret[j][i] = kNoDeps; + } + } + + return ret; + } + + const OpGraphView graph_; +}; + +/** + * Find the nearest downstream computation op handle. If the op is a + * computation op, just return itself. + */ +static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( + OpHandleBase *op, size_t scope_idx) { + std::queue q; + std::unordered_set visited; + q.push(op); + do { + auto *op = q.front(); + q.pop(); + auto *compute_op = dynamic_cast(op); + if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { + return compute_op; + } + for (auto *out_var : op->Outputs()) { + for (auto *pending_op : out_var->PendingOps()) { + if (visited.count(pending_op)) continue; + visited.insert(pending_op); } } - } while (!queue.empty()); + } while (!q.empty()); return nullptr; } -static void AddDependencyBetween(OpHandleBase *in, OpHandleBase *out, - ir::Graph *graph) { - auto it = std::find_if( - in->Outputs().begin(), in->Outputs().end(), [](VarHandleBase *var) { - return dynamic_cast(var) != nullptr; - }); - - if (it != in->Outputs().end()) { - out->AddInput(*it); - } else { - auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get(kGraphDepVars).emplace(dep_var); - in->AddOutput(dep_var); - out->AddInput(dep_var); +static std::unordered_set +ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, + const ShrinkDepsOpFunctor &shrink_func, + bool *ok) { + // stage one. Get last op for variable. + std::unordered_set candidates; + { + if (var->PendingOps().empty() && var->GeneratedOp()) { + // No operator depends on this variable. So the last operator is the op + // who generates this variable. + candidates.emplace(var->GeneratedOp()); + } else { + candidates = var->PendingOps(); + } + + // No pending ops or generated op is nullptr + if (candidates.empty()) { + *ok = false; + return {}; + } + } + + // stage two. Try to cast them to computation op. + // return (*ok=false) when failed. + // + // The reason why we cannot make any types of op handle to be the last lived + // op is: + // some op handle may operate on many DeviceContext, however, our garbage + // collector can only wait one DeviceContext for now. So currently, we wait + // the nearest compute op. + std::unordered_set computation_op; + { + for (auto *op : candidates) { + auto *compute_op = + FindNextComputationOpHandleOrReturnItself(op, scope_idx); + if (compute_op == nullptr) { + *ok = false; + return {}; + } + computation_op.emplace(compute_op); + } } + + // stage three. Try to shrink computation op if they depend on each other. + // Get the smallest set of the most ops. + *ok = true; + return shrink_func(computation_op); +} + +static VarDesc *TryGetLatestVarDesc(const std::vector &vars) { + VarDesc *var_desc = nullptr; + std::find_if(vars.rbegin(), vars.rend(), [&](VarHandle *var_handle) -> bool { + var_desc = var_handle->Node()->Var(); + return var_desc != nullptr; + }); + return var_desc; } std::unique_ptr ReferenceCountPass::ApplyImpl( std::unique_ptr graph) const { - auto &ref_cnts = Get(kGlobalReferenceCount); - auto &cur_ref_cnts = Get(kCurReferenceCount); - auto &gcs = Get(kGarbageCollector); - - // It is not easy to find the right reference counts of varaibles in graph - // Step 1: Find all variables in computation ops - // Step 2: Find all variables in non-computation ops which refers to variables - // in computation ops - std::unordered_set names; - std::unordered_map - compute_ref_cnt_map; - - auto get_ref_cnts_from_compute_op = [&]( - OpHandleBase *op, const std::vector &vars) { - std::vector var_names_in_op; - auto *compute_op = dynamic_cast(op); - if (compute_op == nullptr || - !platform::is_gpu_place(compute_op->GetPlace())) - return var_names_in_op; - auto place = boost::get(compute_op->GetPlace()); - for (VarHandleBase *var_handle_base : vars) { - auto *var_handle = dynamic_cast(var_handle_base); - if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; - - if (!platform::is_gpu_place(var_handle->place_) || - boost::get(var_handle->place_) != place) - continue; + auto &ref_cnts = Get>(kGlobalReferenceCount); + auto &last_live_ops_of_vars = + Get>(kLastLiveOpsOfVars); + + PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(), + "Last Live Ops and Reference Counts of vars should be " + "initialized at here."); - VarDesc *var_desc = var_handle->Node()->Var(); - auto var_name = var_handle->Node()->Name(); + const auto &vars = graph->Get(kGraphVars); - // This is weird but there is really some variables without var_desc - // in computation_op - if (var_desc == nullptr) { - var_desc = compute_op->Node()->Op()->Block()->FindVar(var_name); - if (var_desc == nullptr) continue; + last_live_ops_of_vars.resize(vars.size()); + ref_cnts.resize(vars.size()); + + ShrinkDepsOpFunctor shrink_func( + ir::FilterByNodeWrapper(*graph)); + + for (size_t i = 0; i < vars.size(); ++i) { + for (auto &name_var_pair : vars[i]) { + // Whether this variable can be reused or deleted? If not, we do not + // compute reference counts and dependencies. + VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second); + + if (var_desc == nullptr || var_desc->Persistable()) { + continue; } - if (var_desc->Persistable()) continue; auto var_type = var_desc->Proto()->type().type(); if (var_type != proto::VarType::LOD_TENSOR && - var_type != proto::VarType::SELECTED_ROWS) { + var_type != proto::VarType::SELECTED_ROWS && + var_type != proto::VarType::LOD_TENSOR_ARRAY) { + // Var type cannot be deleted continue; } - // compute op only runs in one device - if (ref_cnts[place.device]->count(var_name)) - ++(*ref_cnts[place.device])[var_name]; - else - (*ref_cnts[place.device])[var_name] = 1; + bool ok; + auto result = ExtractComputationOpFromLastLivedVar( + name_var_pair.second.back(), i, shrink_func, &ok); - names.insert(var_name); - var_names_in_op.push_back(var_name); - } - return var_names_in_op; - }; - - auto update_ref_cnts_from_non_compute_op = [&]( - OpHandleBase *op, const std::vector &vars) { - if (dynamic_cast(op) != nullptr) return; - for (VarHandleBase *var_handle_base : vars) { - auto *var_handle = dynamic_cast(var_handle_base); - if (var_handle == nullptr || !var_handle->Node()->IsVar()) continue; - - auto var_name = var_handle->Node()->Name(); - auto var_place = var_handle->place_; - if (!platform::is_gpu_place(var_place)) continue; - auto place = boost::get(var_place); - if (names.count(var_name) == 0) continue; - if (ref_cnts.count(place.device) && - ref_cnts[place.device]->count(var_name)) { - ++(*ref_cnts[place.device])[var_name]; - - auto *next_compute_op = FindNextComputationOpHandle(var_handle); - if (next_compute_op != nullptr) { - if (compute_ref_cnt_map.count(next_compute_op)) { - compute_ref_cnt_map[next_compute_op]->AddVar(var_name); - VLOG(5) << "Add reference count of " << var_name << " to Operator " - << next_compute_op->Name(); - } else { - // Create new reference_count_op_handle - ir::Node *ref_cnt_node = graph->CreateEmptyNode( - "reference_count", ir::Node::Type::kOperation); - auto *ref_cnt_handle = new ReferenceCountOpHandle( - ref_cnt_node, next_compute_op->GetScope(), place, {var_name}, - gcs[place.device].get(), cur_ref_cnts[place.device].get()); - AddDependencyBetween(next_compute_op, ref_cnt_handle, graph.get()); - compute_ref_cnt_map[next_compute_op] = ref_cnt_handle; - } - } + if (ok) { + auto &var_name = name_var_pair.first; + PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty", + var_name); + ref_cnts[i].emplace(var_name, result.size()); + last_live_ops_of_vars[i].emplace(var_name, std::move(result)); } } - }; - - auto all_ops = ir::FilterByNodeWrapper(*graph); - for (auto &op : all_ops) { - auto in_var_names = get_ref_cnts_from_compute_op(op, op->Inputs()); - auto out_var_names = get_ref_cnts_from_compute_op(op, op->Outputs()); - if (in_var_names.empty() && out_var_names.empty()) continue; - in_var_names.insert(in_var_names.end(), out_var_names.begin(), - out_var_names.end()); - auto *compute_op = dynamic_cast(op); - auto place = boost::get(compute_op->GetPlace()); - ir::Node *ref_cnt_node = - graph->CreateEmptyNode("reference_count", ir::Node::Type::kOperation); - auto *ref_cnt_handle = new ReferenceCountOpHandle( - ref_cnt_node, compute_op->GetScope(), place, in_var_names, - gcs[place.device].get(), cur_ref_cnts[place.device].get()); - AddDependencyBetween(compute_op, ref_cnt_handle, graph.get()); - compute_ref_cnt_map[compute_op] = ref_cnt_handle; - } - - for (auto &op : all_ops) { - update_ref_cnts_from_non_compute_op(op, op->Inputs()); - update_ref_cnts_from_non_compute_op(op, op->Outputs()); - } - - std::vector new_all_ops; - new_all_ops.reserve(compute_ref_cnt_map.size() + all_ops.size()); - for (auto &op : all_ops) { - new_all_ops.emplace_back(std::move(op)); - auto it = compute_ref_cnt_map.find(new_all_ops.back()); - if (it != compute_ref_cnt_map.end()) { - // Add LeafNode to ReferenceCountOpHandle - auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); - graph->Get(kGraphDepVars).emplace(dummy_leaf); - it->second->AddOutput(dummy_leaf); - new_all_ops.emplace_back(std::move(it->second)); - } } - all_ops.swap(new_all_ops); return graph; } @@ -205,5 +258,4 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( REGISTER_PASS(reference_count_pass, paddle::framework::details::ReferenceCountPass) .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount) - .RequirePassAttr(paddle::framework::details::kCurReferenceCount) - .RequirePassAttr(paddle::framework::details::kGarbageCollector); + .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars); diff --git a/paddle/fluid/framework/details/reference_count_pass.h b/paddle/fluid/framework/details/reference_count_pass.h index 7081280b0600b9c1985987d02d679c298ad4b8bd..bcbef027354ef5a5fcc7da28103a9565982c7631 100644 --- a/paddle/fluid/framework/details/reference_count_pass.h +++ b/paddle/fluid/framework/details/reference_count_pass.h @@ -14,7 +14,6 @@ #pragma once -#include "paddle/fluid/framework/details/reference_count_op_handle.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h" @@ -22,10 +21,6 @@ namespace paddle { namespace framework { namespace details { -constexpr char kGlobalReferenceCount[] = "reference_count"; -constexpr char kCurReferenceCount[] = "current_reference_count"; -constexpr char kGarbageCollector[] = "garbage_collector"; - class ReferenceCountPass : public ir::Pass { protected: std::unique_ptr ApplyImpl( diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.cc b/paddle/fluid/framework/details/reference_count_pass_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..89bd08c2d041d795205b29bb29aba311d1dbd932 --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass_helper.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" + +namespace paddle { +namespace framework { +namespace details {} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.h b/paddle/fluid/framework/details/reference_count_pass_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..1c083dbf001b08e40a54cc89b21c3dea1f18f16a --- /dev/null +++ b/paddle/fluid/framework/details/reference_count_pass_helper.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/garbage_collector.h" + +namespace paddle { +namespace framework { +namespace details { + +class ComputationOpHandle; + +using ReferenceCountMap = std::unordered_map; + +using AtomicReferenceCountMap = + std::unordered_map>; + +using GarbageCollectorMap = + std::map>; + +const char kGlobalReferenceCount[] = "global_reference_count"; +const char kRuntimeReferenceCount[] = "runtime_reference_count"; +const char kGarbageCollector[] = "garbage_collector"; +const char kAllPlaces[] = "all_places"; + +using LastLiveOpsOfVars = + std::unordered_map>; +const char kLastLiveOpsOfVars[] = "last_live_ops_of_var"; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index 85898af417e39358bd63fd10a5faffcd7d88fc5d..edb7b5e70ac0a082fac9906bb2fa6bc2064ffde9 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -18,9 +18,6 @@ #include #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/profiler.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/details/reference_count_op_handle.h" -#endif namespace paddle { namespace framework { @@ -67,31 +64,12 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun", nullptr); drop_scope_counter_ += 1; -#ifdef PADDLE_WITH_CUDA - const std::string gc_name = "garbage_collector"; - DeviceGarbageCollectorMap *gc = nullptr; - // FIXME(Yancey1989): need to fix gc failed on parallel graph mode - if (strategy_.type_ != ExecutionStrategy::kParallelGraph) { - gc = Graph().Has(gc_name) - ? &(Graph().Get(gc_name)) - : nullptr; - } -#endif - if (!fetch_tensors.empty() || drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) { drop_scope_counter_ = 0; // Wait All computational streams for (auto p : places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); -#ifdef PADDLE_WITH_CUDA - if (gc != nullptr && platform::is_gpu_place(p)) { - auto gpu_place = boost::get(p); - auto &gc_at_place = gc->at(gpu_place.device); - gc_at_place->Wait(); - gc_at_place->Reset(); - } -#endif } for (auto &scope : local_scopes_) { auto &local_scope = diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 73cec21e20f2fd26e144872f1f7b5bb7065adb74..0c4bd336c5b36eff7d93e4099c96cd3c5990ac17 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/executor.h" +#include #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" -#include "paddle/fluid/framework/ngraph_operator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/transfer_scope_cache.h" @@ -26,6 +26,10 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_NGRAPH +#include "paddle/fluid/framework/ngraph_operator.h" +#endif + DECLARE_bool(benchmark); DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DEFINE_bool(use_ngraph, false, "Use NGRAPH to run"); @@ -38,11 +42,43 @@ namespace { int kProgramId = -1; } // namespace +static std::unordered_map GetNonPersistableReferenceCounts( + const BlockDesc& block, const std::vector& skip_var_list) { + std::unordered_map ref_cnts; + std::unordered_set skip_vars(skip_var_list.begin(), + skip_var_list.end()); + + auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) { + for (auto& name_pair : name_map) { + for (auto& name : name_pair.second) { + if (skip_vars.count(name)) continue; + auto* var_desc = block.FindVar(name); + if (var_desc == nullptr || var_desc->Persistable()) continue; + auto type = var_desc->Proto()->type().type(); + if (type != proto::VarType::LOD_TENSOR && + type != proto::VarType::SELECTED_ROWS && + type != proto::VarType::LOD_TENSOR_ARRAY) { + continue; + } + ++ref_cnts[name]; + } + } + }; + + for (auto op_desc : block.AllOps()) { + update_ref_cnts(op_desc, op_desc->Inputs()); + update_ref_cnts(op_desc, op_desc->Outputs()); + } + return ref_cnts; +} + ExecutorPrepareContext::ExecutorPrepareContext( - const framework::ProgramDesc& prog, size_t block_id) + const framework::ProgramDesc& prog, size_t block_id, + const std::vector& skip_ref_cnt_vars) : prog_(prog), block_id_(block_id) { if (GetEagerDeletionThreshold() >= 0) { - ref_cnts_ = GetNonPersistableReferenceCount(prog_, block_id_); + global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id), + skip_ref_cnt_vars); } } @@ -50,28 +86,40 @@ ExecutorPrepareContext::~ExecutorPrepareContext() { VLOG(5) << "destroy ExecutorPrepareContext"; } -template -static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, - GarbageCollector* gc, - RefCntMap* ref_cnts) { - std::unordered_set erase_tensors; +static void DeleteUnusedTensors( + const Scope& scope, const OperatorBase* op, GarbageCollector* gc, + std::unordered_map* ref_cnts) { + std::deque> garbages; auto handler = [&](const VariableNameMap& name_map) { for (auto& name_pair : name_map) { for (auto& name : name_pair.second) { auto it = ref_cnts->find(name); if (it == ref_cnts->end()) continue; - if ((it->second)-- == 1) { - auto* var = scope.FindVar(name); - if (var != nullptr) { - VLOG(10) << "Erase tensor \'" << name << "\'"; - if (var->IsType()) { - erase_tensors.insert(var->GetMutable()); - } else if (var->IsType()) { - erase_tensors.insert( - var->GetMutable()->mutable_value()); - } + if (--(it->second) != 0) { + continue; + } + auto* var = scope.FindVar(name); + if (var != nullptr) { + continue; + } + + VLOG(2) << "Erase variable " << name; + if (var->IsType()) { + garbages.emplace_back( + var->GetMutable()->MoveMemoryHolder()); + } else if (var->IsType()) { + garbages.emplace_back(var->GetMutable() + ->mutable_value() + ->MoveMemoryHolder()); + } else if (var->IsType()) { + auto* lod_tensor_arr = var->GetMutable(); + for (auto& t : *lod_tensor_arr) { + garbages.emplace_back(t.MoveMemoryHolder()); } + } else { + PADDLE_THROW("Type %s of %s is not supported eager deletion", + var->Type().name(), name); } } } @@ -80,19 +128,19 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, handler(op->Inputs()); handler(op->Outputs()); - if (!erase_tensors.empty()) { - gc->Add(erase_tensors); + if (!garbages.empty()) { + gc->Add(std::move(garbages)); } } static void EnableFusedOp(ExecutorPrepareContext* ctx) { #ifdef PADDLE_WITH_NGRAPH VLOG(3) << "use_ngraph=True"; - auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_); + auto intervals = NgraphOperator::NgraphOpIntervals(&ctx->ops_); for (auto& interval : intervals) { - auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_, - interval.at(0), interval.at(1)); - *interval[0] = std::unique_ptr(fused_op); + auto* ng_op = new NgraphOperator(ctx->prog_, ctx->block_id_, interval.at(0), + interval.at(1)); + *interval[0] = std::unique_ptr(ng_op); } for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) { ctx->ops_.erase(it->at(0) + 1, it->at(1)); @@ -322,9 +370,10 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } std::unique_ptr Executor::Prepare( - const ProgramDesc& program, int block_id) { + const ProgramDesc& program, int block_id, + const std::vector& skip_ref_cnt_vars) { std::unique_ptr ctx( - new ExecutorPrepareContext(program, block_id)); + new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars)); PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); auto& block = program.Block(block_id); for (auto& op_desc : block.AllOps()) { @@ -335,16 +384,28 @@ std::unique_ptr Executor::Prepare( } std::vector> Executor::Prepare( - const ProgramDesc& program, const std::vector& block_ids) { + const ProgramDesc& program, const std::vector& block_ids, + const std::vector>& skip_ref_cnt_vars) { + PADDLE_ENFORCE( + skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(), + "skip_ref_cnt_vars should be either empty or equals to block number %d", + block_ids.size()); std::vector> result; + size_t idx = 0; for (auto& bid : block_ids) { - auto* ctx = new ExecutorPrepareContext(program, bid); + ExecutorPrepareContext* ctx; + if (skip_ref_cnt_vars.empty()) { + ctx = new ExecutorPrepareContext(program, bid); + } else { + ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]); + } PADDLE_ENFORCE_LT(static_cast(bid), program.Size()); auto& block = program.Block(bid); for (auto& op_desc : block.AllOps()) { ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); } result.push_back(std::shared_ptr(ctx)); + ++idx; } return result; } @@ -362,22 +423,23 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } int64_t max_memory_size = GetEagerDeletionThreshold(); - std::unique_ptr> gc; - // WhileOp would set keep_kids to true, - // because WhileGradOp needs the scopes created in WhileOp. - // Perhaps, we should not perform eager deletion in WhileOp - // The scopes and variables created by WhileOp would be deleted - // in WhileGradOp. + std::unique_ptr gc; + // skip while_op and while_grad_op temporarily if (max_memory_size >= 0 && !keep_kids) { ctx->ResetReferenceCount(); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { - gc.reset(new DefaultStreamGarbageCollector( - boost::get(place_), max_memory_size)); - } else { + if (IsFastEagerDeletionModeEnabled()) { + gc.reset(new UnsafeFastGPUGarbageCollector( + boost::get(place_), max_memory_size)); + } else { + gc.reset(new DefaultStreamGarbageCollector( + boost::get(place_), max_memory_size)); + } + } else if (platform::is_cpu_place(place_)) { #endif - gc.reset(new CPUGarbageCollector( - boost::get(place_), max_memory_size)); + gc.reset(new CPUGarbageCollector(boost::get(place_), + max_memory_size)); #ifdef PADDLE_WITH_CUDA } #endif @@ -386,17 +448,13 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, for (auto& op : ctx->ops_) { op->Run(*local_scope, place_); - if (gc != nullptr) { + if (gc) { DeleteUnusedTensors(*local_scope, op.get(), gc.get(), - &(ctx->cur_ref_cnts_)); + &(ctx->runtime_ref_cnts_)); } } - if (gc != nullptr) { - gc->Wait(); - } else { - platform::DeviceContextPool::Instance().Get(place_)->Wait(); - } + platform::DeviceContextPool::Instance().Get(place_)->Wait(); if (local_scope != scope) { scope->DeleteScope(local_scope); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 2d47903ffbd8d821b7c31386b225fe5e65ca2720..5a040ac641588ad4d89d1f6e4c0d6c296eff38eb 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -27,52 +27,21 @@ limitations under the License. */ namespace paddle { namespace framework { -template -std::unordered_map GetNonPersistableReferenceCount( - const ProgramDesc& prog, size_t block_id) { - auto& block = prog.Block(block_id); - std::unordered_map ref_cnts; - - auto update_ref_cnts = [&](OpDesc* op_desc, const VariableNameMap& name_map) { - for (auto& name_pair : name_map) { - for (auto& name : name_pair.second) { - auto* var_desc = block.FindVar(name); - if (var_desc == nullptr || var_desc->Persistable()) continue; - auto type = var_desc->Proto()->type().type(); - if (type != proto::VarType::LOD_TENSOR && - type != proto::VarType::SELECTED_ROWS) { - continue; - } - - auto it = ref_cnts.find(name); - if (it != ref_cnts.end()) { - ++it->second; - } else { - ref_cnts[name] = 1; - } - } - } - }; - - for (auto op_desc : block.AllOps()) { - update_ref_cnts(op_desc, op_desc->Inputs()); - update_ref_cnts(op_desc, op_desc->Outputs()); - } - return ref_cnts; -} - struct ExecutorPrepareContext { - ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); + ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id, + const std::vector& skip_ref_cnt_vars = + std::vector()); + ~ExecutorPrepareContext(); - void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; } + void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; } const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; - std::unordered_map ref_cnts_; - std::unordered_map cur_ref_cnts_; + std::unordered_map global_ref_cnts_; + std::unordered_map runtime_ref_cnts_; }; class Executor { @@ -108,10 +77,14 @@ class Executor { const std::string& fetch_holder_name = "fetch"); static std::unique_ptr Prepare( - const ProgramDesc& program, int block_id); + const ProgramDesc& program, int block_id, + const std::vector& skip_ref_cnt_vars = + std::vector()); static std::vector> Prepare( - const ProgramDesc& program, const std::vector& block_ids); + const ProgramDesc& program, const std::vector& block_ids, + const std::vector>& skip_ref_cnt_vars = + std::vector>()); void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); diff --git a/paddle/fluid/framework/feed_fetch_method.cc b/paddle/fluid/framework/feed_fetch_method.cc index 3e9353f5cf67d8de62c5551f12ea786e49190549..6338be75a4b1d3c4caf7a6f7add4d05fec690340 100644 --- a/paddle/fluid/framework/feed_fetch_method.cc +++ b/paddle/fluid/framework/feed_fetch_method.cc @@ -16,7 +16,9 @@ limitations under the License. */ #include #include #include "glog/logging.h" +#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/platform/place.h" namespace paddle { namespace framework { @@ -53,5 +55,12 @@ LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, return tensor; } +LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name) { + Variable* var = scope.FindVar(var_name); + PADDLE_ENFORCE(var, "%s no in scope", var_name); + PADDLE_ENFORCE(var->IsType(), "Only support lod tensor now."); + return *var->GetMutable(); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/feed_fetch_method.h b/paddle/fluid/framework/feed_fetch_method.h index 7f504bfd232862c014cb59b6e8301eec74e0351f..031f8e01aa6128b803dcbfb990778e87d4fafc13 100644 --- a/paddle/fluid/framework/feed_fetch_method.h +++ b/paddle/fluid/framework/feed_fetch_method.h @@ -27,5 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input, LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, size_t index); +LoDTensor& GetVariableTensor(const Scope& scope, const std::string& var_name); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/garbage_collector.cc b/paddle/fluid/framework/garbage_collector.cc new file mode 100644 index 0000000000000000000000000000000000000000..54d9d0dc018b08decb2ff8965659bab98e81f3ab --- /dev/null +++ b/paddle/fluid/framework/garbage_collector.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/cuda_device_guard.h" +#endif +#include "paddle/fluid/framework/garbage_collector.h" + +namespace paddle { +namespace framework { + +GarbageCollector::GarbageCollector(const platform::Place &place, + size_t max_memory_size) + : max_memory_size_((std::max)(max_memory_size, static_cast(1))) { + garbages_.reset(new GarbageQueue()); + dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); +} + +CPUGarbageCollector::CPUGarbageCollector(const platform::CPUPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + +void CPUGarbageCollector::ClearCallback(const std::function &callback) { + callback(); +} + +#ifdef PADDLE_WITH_CUDA +UnsafeFastGPUGarbageCollector::UnsafeFastGPUGarbageCollector( + const platform::CUDAPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + +void UnsafeFastGPUGarbageCollector::ClearCallback( + const std::function &callback) { + callback(); +} + +DefaultStreamGarbageCollector::DefaultStreamGarbageCollector( + const platform::CUDAPlace &place, size_t max_memory_size) + : GarbageCollector(place, max_memory_size) {} + +void DefaultStreamGarbageCollector::Wait() const { + static_cast(this->dev_ctx_) + ->WaitStreamCallback(); +} + +void DefaultStreamGarbageCollector::ClearCallback( + const std::function &callback) { + static_cast(this->dev_ctx_) + ->AddStreamCallback(callback); +} + +StreamGarbageCollector::StreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size) + : GarbageCollector(place, max_memory_size) { + platform::CUDADeviceGuard guard(place.device); + PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + callback_manager_.reset(new platform::StreamCallbackManager(stream_)); +} + +StreamGarbageCollector::~StreamGarbageCollector() { + auto place = boost::get(this->dev_ctx_->GetPlace()); + platform::CUDADeviceGuard guard(place.device); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); +} + +cudaStream_t StreamGarbageCollector::stream() const { return stream_; } + +void StreamGarbageCollector::Wait() const { callback_manager_->Wait(); } + +void StreamGarbageCollector::ClearCallback( + const std::function &callback) { + callback_manager_->AddCallback(callback); +} +#endif +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/garbage_collector.h b/paddle/fluid/framework/garbage_collector.h index 818b3334ea4171fd7a9cbaa896ee1672e8ecca51..2768671029c06562aa0d2e5eea3d3ff61d900ab5 100644 --- a/paddle/fluid/framework/garbage_collector.h +++ b/paddle/fluid/framework/garbage_collector.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include #include @@ -24,134 +23,74 @@ namespace paddle { namespace framework { -// T should have memory_size() and clear() method -template class GarbageCollector { public: - GarbageCollector(const platform::Place &place, size_t max_memory_size) - : max_memory_size_((std::max)(max_memory_size, static_cast(1))) { - garbages_.reset(new std::deque()); - dev_ctx_ = platform::DeviceContextPool::Instance().Get(place); - } + using GarbageQueue = std::deque>; - virtual ~GarbageCollector() {} + GarbageCollector(const platform::Place &place, size_t max_memory_size); - void Reset() { - std::lock_guard guard(mutex_); - garbages_.reset(new std::deque()); - cur_memory_size_ = 0; - } + virtual ~GarbageCollector() = default; + + virtual void Wait() const {} template - void Add(const Container &objs) { - Add(objs, []() {}); - } + void Add(Container &&objs); template - void Add(const Container &objs, Callback &&callback) { - std::shared_ptr> clear_deque; - { - std::lock_guard guard(mutex_); - for (auto *obj : objs) { - garbages_->push_back(obj); - cur_memory_size_ += obj->memory_size(); - } - if (cur_memory_size_ >= max_memory_size_) { - cur_memory_size_ = 0; - clear_deque = garbages_; - garbages_.reset(new std::deque()); - } - } - - if (clear_deque != nullptr) { - callback(); - ClearCallback([=]() { - for (auto *obj : *clear_deque) obj->clear(); - }); - } - } - - virtual void Wait() const {} + void Add(Container &&objs, Callback &&callback); protected: virtual void ClearCallback(const std::function &callback) = 0; platform::DeviceContext *dev_ctx_; - std::shared_ptr> garbages_; + std::unique_ptr garbages_; mutable std::mutex mutex_; const size_t max_memory_size_; - size_t cur_memory_size_ = 0; + size_t cur_memory_size_{0}; }; -template -class CPUGarbageCollector : public GarbageCollector { +class CPUGarbageCollector : public GarbageCollector { public: - CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size) - : GarbageCollector(place, max_memory_size) {} + CPUGarbageCollector(const platform::CPUPlace &place, size_t max_memory_size); protected: - void ClearCallback(const std::function &callback) override { - callback(); - } + void ClearCallback(const std::function &callback) override; }; #ifdef PADDLE_WITH_CUDA -template -class DefaultStreamGarbageCollector : public GarbageCollector { +class UnsafeFastGPUGarbageCollector : public GarbageCollector { public: - DefaultStreamGarbageCollector(const platform::CUDAPlace &place, - size_t max_memory_size) - : GarbageCollector(place, max_memory_size) {} + UnsafeFastGPUGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size); - cudaStream_t stream() const { - return static_cast(this->dev_ctx_) - ->stream(); - } + protected: + void ClearCallback(const std::function &callback) override; +}; - void Wait() const override { - this->dev_ctx_->Wait(); - static_cast(this->dev_ctx_) - ->WaitStreamCallback(); - } +class DefaultStreamGarbageCollector : public GarbageCollector { + public: + DefaultStreamGarbageCollector(const platform::CUDAPlace &place, + size_t max_memory_size); + + void Wait() const override; protected: - void ClearCallback(const std::function &callback) override { - static_cast(this->dev_ctx_) - ->AddStreamCallback(callback); - } + void ClearCallback(const std::function &callback) override; }; -template -class StreamGarbageCollector : public GarbageCollector { +class StreamGarbageCollector : public GarbageCollector { public: StreamGarbageCollector(const platform::CUDAPlace &place, - size_t max_memory_size) - : GarbageCollector(place, max_memory_size) { - PADDLE_ENFORCE(cudaSetDevice(place.device)); - PADDLE_ENFORCE(cudaStreamCreate(&stream_)); - callback_manager_.reset(new platform::StreamCallbackManager(stream_)); - } + size_t max_memory_size); - ~StreamGarbageCollector() { - auto place = boost::get(this->dev_ctx_->GetPlace()); - PADDLE_ENFORCE(cudaSetDevice(place.device)); - PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); - PADDLE_ENFORCE(cudaStreamDestroy(stream_)); - } + ~StreamGarbageCollector(); - void Wait() const override { - PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); - std::lock_guard guard(this->mutex_); - callback_manager_->Wait(); - } + void Wait() const override; - cudaStream_t stream() const { return stream_; } + cudaStream_t stream() const; protected: - void ClearCallback(const std::function &callback) override { - std::lock_guard guard(this->mutex_); - callback_manager_->AddCallback(callback); - } + void ClearCallback(const std::function &callback) override; private: cudaStream_t stream_; @@ -159,5 +98,33 @@ class StreamGarbageCollector : public GarbageCollector { }; #endif +template +void GarbageCollector::Add(Container &&objs) { + Add(std::forward(objs), []() {}); +} + +template +void GarbageCollector::Add(Container &&objs, Callback &&callback) { + GarbageQueue *garbage_queue = nullptr; + { + std::lock_guard guard(mutex_); + for (auto &obj : objs) { + if (!obj) continue; + cur_memory_size_ += obj->size(); + garbages_->push_back(std::move(obj)); + } + if (cur_memory_size_ >= max_memory_size_) { + cur_memory_size_ = 0; + garbage_queue = garbages_.release(); + garbages_.reset(new GarbageQueue()); + } + } + + if (garbage_queue) { + callback(); + ClearCallback([garbage_queue]() { delete garbage_queue; }); + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index fc91564bbaecf7b1725908fc1eb8b1e4d2e20d32..8679118fe28b1c68aea30caf711441823b5255c0 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -38,9 +38,8 @@ void CheckProgram(const ProgramDesc &program) { switch (role_id) { case _INT(OpRole::kForward): if (visit.find(_INT(OpRole::kBackward)) != visit.end()) { - LOG(ERROR) - << "Cannot add backward operator before forward operator %s." - << op->Type(); + LOG(ERROR) << "Cannot add backward operator before forward operator " + << op->Type(); } break; case _INT(OpRole::kBackward): diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index bb2d953afbd64dc33d735f483a8eafd9ce951530..47fcf96a3f92b1f915e5254fff36feb8b2870730 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -73,14 +73,21 @@ class Graph { } bool Has(const std::string &attr_name) const { - return attrs_.find(attr_name) != attrs_.end(); + return attrs_.count(attr_name) > 0; } template AttrType &Get(const std::string &attr_name) const { PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.", attr_name); - return *boost::any_cast(attrs_.at(attr_name)); + try { + return *boost::any_cast(attrs_.at(attr_name)); + } catch (boost::bad_any_cast &) { + PADDLE_THROW( + "Invalid attribute type of %s error, expected: %s, actual: %s", + attr_name, typeid(AttrType *).name(), + attrs_.at(attr_name).type().name()); + } } template diff --git a/paddle/fluid/framework/ir/is_test_pass.cc b/paddle/fluid/framework/ir/is_test_pass.cc index 6d8f020918d4e56fa7f125a659f7f8511ca067ca..57cc98e2ca0175848aa62c62c8ad3b20594b3bde 100644 --- a/paddle/fluid/framework/ir/is_test_pass.cc +++ b/paddle/fluid/framework/ir/is_test_pass.cc @@ -38,7 +38,7 @@ std::unique_ptr IsTestPass::ApplyImpl( for (const Node* n : graph->Nodes()) { if (n->IsOp()) { auto* op = n->Op(); - if (n->RuntimeHasAttr("is_test")) { + if (op->HasAttr("is_test") || op->HasProtoAttr("is_test")) { op->SetAttr("is_test", true); } else if (std::find(begin(op_list), end(op_list), op->Type()) != end(op_list)) { diff --git a/paddle/fluid/framework/ir/is_test_pass_tester.cc b/paddle/fluid/framework/ir/is_test_pass_tester.cc index d9a68c7f1dd2a0dca5204719c4ce6cd9d68292a2..9696441a21661db89146c448742a992d1f7df022 100644 --- a/paddle/fluid/framework/ir/is_test_pass_tester.cc +++ b/paddle/fluid/framework/ir/is_test_pass_tester.cc @@ -104,9 +104,9 @@ TEST(IsTestPass, basic) { auto* op = node->Op(); auto op_name = boost::get(op->GetAttr("name")); if (op_name == "conv3") { - ASSERT_FALSE(node->RuntimeHasAttr("is_test")); + ASSERT_FALSE(op->HasAttr("is_test")); } else { - ASSERT_TRUE(node->RuntimeHasAttr("is_test")); + ASSERT_TRUE(op->HasAttr("is_test")); EXPECT_TRUE(boost::get(op->GetAttr("is_test"))); } } diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc index 9a9314161b0e8d14a525d253572915ed597c716e..951fcb066ce759ebfec0182e1e9dca887e343170 100644 --- a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc @@ -25,12 +25,15 @@ std::unique_ptr MKLDNNPlacementPass::ApplyImpl( const auto& op_types_list = Get>("mkldnn_enabled_op_types"); for (const Node* n : graph->Nodes()) { - if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) { - if (op_types_list.empty()) { - n->Op()->SetAttr("use_mkldnn", true); - } else if (std::find(op_types_list.begin(), op_types_list.end(), - n->Name()) != op_types_list.end()) { - n->Op()->SetAttr("use_mkldnn", true); + if (n->IsOp()) { + auto* op = n->Op(); + if (op->HasAttr("use_mkldnn") || op->HasProtoAttr("use_mkldnn")) { + if (op_types_list.empty()) { + op->SetAttr("use_mkldnn", true); + } else if (std::find(op_types_list.begin(), op_types_list.end(), + n->Name()) != op_types_list.end()) { + op->SetAttr("use_mkldnn", true); + } } } } diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index 7a88cb2b681c1aa5e1b75481b1aba66a125a1d9c..eac67108e2106e986cbe1255a64c956153bc5560 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -30,28 +30,6 @@ std::unique_ptr CreateNodeForTest(const std::string &name, return std::unique_ptr(new Node(name, type)); } -bool Node::RuntimeHasAttr(const std::string &name) const { - if (Op()->HasAttr(name)) { - return true; - } else { - auto &op_info = OpInfoMap::Instance(); - auto op_type = Op()->Type(); - if (op_info.Has(op_type)) { - auto op_info_ptr = op_info.Get(op_type); - if (op_info_ptr.HasOpProtoAndChecker()) { - const proto::OpProto &proto = op_info_ptr.Proto(); - for (int i = 0; i != proto.attrs_size(); ++i) { - const proto::OpProto::Attr &attr = proto.attrs(i); - if (attr.name() == name) { - return true; - } - } - } - } - } - return false; -} - } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 1044a96430f060b750580ea0b225787ba6ebd2a0..d2a393b3f19e9aab79098757dae663d030b0fa2b 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -108,18 +108,6 @@ class Node { Name().find(ir::Node::kControlDepVarName) != std::string::npos; } - // RuntimeHasAttr is different with HasAttr now. - // 1. For Op()->HasAttr(), it judges whether a stored program_desc_ has attr, - // thus, if stored program_desc_ are old which don't have an attr, a new - // library which adds the attr already will fail on this function. - // Details: - // https://github.com/PaddlePaddle/Paddle/pull/14608#issuecomment-442309087 - // 2. For Op()->RuntimeHasAttr, it judges the attr in runtime to avoid above - // problem. - // TODO(luotao): Maybe we should enhance HasAttr later, instead of adding - // RuntimeHasAttr. - bool RuntimeHasAttr(const std::string& name) const; - std::vector inputs; std::vector outputs; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index a3559247db6703d486ed01ce9f2058e671443096..27746ff1453b1b336da8c31497c066c338843b68 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -51,11 +51,18 @@ class Pass { AttrType &Get(const std::string &attr_name) const { PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), "%s attr not registered for pass.", attr_name); - return *boost::any_cast(attrs_.at(attr_name)); + try { + return *boost::any_cast(attrs_.at(attr_name)); + } catch (boost::bad_any_cast &) { + PADDLE_THROW( + "Invalid attribute type of %s error, expected: %s, actual: %s", + attr_name, typeid(AttrType *).name(), + attrs_.at(attr_name).type().name()); + } } bool Has(const std::string &attr_name) const { - return attrs_.find(attr_name) != attrs_.end(); + return attrs_.count(attr_name) > 0; } void Erase(const std::string &attr_name) { diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index e22c29037718a60ff7f24404d7749600e2edb80b..a5acfd70449e92663cb66ef90a141c087ff6ec88 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #include #include #include @@ -27,14 +26,15 @@ namespace paddle { namespace framework { static std::shared_ptr GetNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, const VariableNameMap& var_map, std::shared_ptr< std::unordered_map>> ngb_node_map) { - auto& var_names = var_map.at(prm); + auto& var_names = var_map.at(name); PADDLE_ENFORCE_EQ(var_names.size(), 1, - "op %s prm %s expects one associated var", op->Type(), prm); + "op %s name %s expects one associated var", op->Type(), + name); if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) { return (*ngb_node_map)[var_names[0]]; } else { @@ -43,42 +43,42 @@ static std::shared_ptr GetNode( } static std::shared_ptr GetInputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr< std::unordered_map>> ngb_node_map) { - return GetNode(op, prm, op->Inputs(), ngb_node_map); + return GetNode(op, name, op->Inputs(), ngb_node_map); } static std::shared_ptr GetOutputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr< std::unordered_map>> ngb_node_map) { - return GetNode(op, prm, op->Outputs(), ngb_node_map); + return GetNode(op, name, op->Outputs(), ngb_node_map); } static void SetOutputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr node, std::shared_ptr< std::unordered_map>> ngb_node_map) { - auto& var_names = op->Outputs().at(prm); + auto& var_names = op->Outputs().at(name); if (var_names.size() == 1) { (*ngb_node_map)[var_names[0]] = node; } else if (var_names.size() == 0) { (*ngb_node_map)[""] = node; } else { - PADDLE_THROW("prm %s has more than 1 var_names.", prm); + PADDLE_THROW("name %s has more than 1 var_names.", name); } } static bool HasOutput(const std::shared_ptr& op, - const std::string prm) { + const std::string name) { auto& outputs = op->Outputs(); - if (outputs.find(prm) == outputs.end()) return false; - return outputs.at(prm).size() > 0; + if (outputs.find(name) == outputs.end()) return false; + return outputs.at(name).size() > 0; } template @@ -118,4 +118,3 @@ void NgraphBridge::BuildNgNode(const std::shared_ptr& op) { } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_bridge.h b/paddle/fluid/framework/ngraph_bridge.h index 9ed6b9510942136a61faa5755fd8fa74286939a8..5ad7b8daeb6a782515e50fc87ca7188b46308390 100644 --- a/paddle/fluid/framework/ngraph_bridge.h +++ b/paddle/fluid/framework/ngraph_bridge.h @@ -14,8 +14,6 @@ limitations under the License. */ #pragma once -#ifdef PADDLE_WITH_NGRAPH - #include #include #include @@ -53,4 +51,3 @@ class NgraphBridge { } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index 3fea753f0659019395c9b214e52a7912058c501c..253de4c61160e52202a0192215a93284f27e5896 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #include #include @@ -58,16 +57,16 @@ typedef enum { /* nGraph support state on ops */ } op_state; // perform graph build through bridge and execute computation -class NgraphOperator { +class NgraphEngine { public: - explicit NgraphOperator(const Scope& scope, const platform::Place& place, - const std::vector>& ops, - const std::unordered_map< - std::string, ngraph::element::Type>& var_type_map, - const std::unordered_set& persist, - const std::unordered_set& fetches, - const std::unordered_set& post_op_inputs, - op_state ng_op_state) + explicit NgraphEngine(const Scope& scope, const platform::Place& place, + const std::vector>& ops, + const std::unordered_map< + std::string, ngraph::element::Type>& var_type_map, + const std::unordered_set& persist, + const std::unordered_set& fetches, + const std::unordered_set& post_op_inputs, + op_state ng_op_state) : scope_(scope), place_(place), fused_ops_(ops), @@ -132,7 +131,7 @@ class NgraphOperator { }; std::vector>::iterator>> -FusedOperator::FusedOpIntervals( +NgraphOperator::NgraphOpIntervals( std::vector>* ops) { std::vector>::iterator>> intervals; @@ -185,7 +184,7 @@ FusedOperator::FusedOpIntervals( return intervals; } -FusedOperator::FusedOperator( +NgraphOperator::NgraphOperator( const ProgramDesc& prog, size_t block_id, std::vector>::iterator start, std::vector>::iterator end, @@ -215,7 +214,7 @@ FusedOperator::FusedOperator( Process(); } -void FusedOperator::Process() { +void NgraphOperator::Process() { auto& bdesc = pdesc_.Block(block_); for (auto& var : bdesc.AllVars()) { if (!(var->GetType() == proto::VarType::SELECTED_ROWS || @@ -251,8 +250,8 @@ void FusedOperator::Process() { } } -void FusedOperator::RunImpl(const Scope& scope, - const platform::Place& place) const { +void NgraphOperator::RunImpl(const Scope& scope, + const platform::Place& place) const { op_state ng_op_state = PARTIAL_TEST; auto& bdesc = pdesc_.Block(block_); for (auto* op : bdesc.AllOps()) { @@ -266,19 +265,19 @@ void FusedOperator::RunImpl(const Scope& scope, ng_op_state = ng_op_state == PARTIAL_TEST ? FULL_TEST : FULL_TRAIN; } - NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_, - persistables_, fetches_, post_op_inputs_, - ng_op_state); - ngraph_op.Run(scope, place); + NgraphEngine ngraph_engine(scope, place, fused_ops_, var_type_map_, + persistables_, fetches_, post_op_inputs_, + ng_op_state); + ngraph_engine.Run(scope, place); } std::unordered_map> - NgraphOperator::func_cache_ = {}; + NgraphEngine::func_cache_ = {}; -std::shared_ptr NgraphOperator::backend_ = +std::shared_ptr NgraphEngine::backend_ = ngraph::runtime::Backend::create("CPU"); -void NgraphOperator::GetNgInputShape(std::shared_ptr op) { +void NgraphEngine::GetNgInputShape(std::shared_ptr op) { op->RuntimeInferShape(scope_, place_); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { @@ -301,7 +300,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr op) { } } -void NgraphOperator::BuildNgNodes() { +void NgraphEngine::BuildNgNodes() { for (auto& var_name : var_out_) { if (var_node_map_->find(var_name) == var_node_map_->end()) { auto* var = scope_.FindVar(var_name); @@ -323,7 +322,7 @@ void NgraphOperator::BuildNgNodes() { } } -void NgraphOperator::BuildNgIO() { +void NgraphEngine::BuildNgIO() { std::unordered_set inputs; std::unordered_set outputs; @@ -395,7 +394,7 @@ void NgraphOperator::BuildNgIO() { } } -void NgraphOperator::BuildNgFunction() { +void NgraphEngine::BuildNgFunction() { BuildNgNodes(); ngraph_function_ = nullptr; ngraph::NodeVector func_outputs; @@ -416,7 +415,7 @@ void NgraphOperator::BuildNgFunction() { std::make_shared(func_outputs, func_inputs); } -std::shared_ptr NgraphOperator::GetCacheKey() { +std::shared_ptr NgraphEngine::GetCacheKey() { auto cache_key = std::make_shared(""); *cache_key += std::to_string(fused_ops_.size()); for (auto& op : fused_ops_) { @@ -444,7 +443,7 @@ std::shared_ptr NgraphOperator::GetCacheKey() { return cache_key; } -void NgraphOperator::GetNgFunction() { +void NgraphEngine::GetNgFunction() { bool cache_on = true; if (cache_on) { std::string cache_key_val = *GetCacheKey(); @@ -459,8 +458,7 @@ void NgraphOperator::GetNgFunction() { } } -void NgraphOperator::Run(const Scope& scope, - const platform::Place& place) const { +void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const { std::vector> t_in; std::vector> t_out; @@ -545,7 +543,6 @@ void NgraphOperator::Run(const Scope& scope, } backend_->call(ngraph_function_, t_out, t_in); -} // NgraphOperator::RunImpl +} // NgraphEngine::RunImpl } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_operator.h b/paddle/fluid/framework/ngraph_operator.h index 3ca023e11111c5b447b2cabbfb8bb29877297f65..ede80f44bea208b66acc3b3f4bc0f4adee4fb860 100644 --- a/paddle/fluid/framework/ngraph_operator.h +++ b/paddle/fluid/framework/ngraph_operator.h @@ -14,8 +14,6 @@ limitations under the License. */ #pragma once -#ifdef PADDLE_WITH_NGRAPH - #include #include #include @@ -34,14 +32,14 @@ limitations under the License. */ namespace paddle { namespace framework { -class FusedOperator : public OperatorBase { +class NgraphOperator : public OperatorBase { public: static std::vector< std::vector>::iterator>> - FusedOpIntervals( + NgraphOpIntervals( std::vector>* ops); - explicit FusedOperator( + explicit NgraphOperator( const ProgramDesc& prog, size_t block_id, std::vector>::iterator start, std::vector>::iterator end, @@ -64,4 +62,3 @@ class FusedOperator : public OperatorBase { }; } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index e8ecd90502933a049cc8f886212579fc061d44ff..dde642764fa5dfce11edcef51ad1be11be331fbc 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -237,6 +237,23 @@ void OpDesc::SetOutput(const std::string ¶m_name, this->outputs_[param_name] = args; } +bool OpDesc::HasProtoAttr(const std::string &name) const { + auto &op_info = OpInfoMap::Instance(); + if (op_info.Has(desc_.type())) { + auto op_info_ptr = op_info.Get(desc_.type()); + if (op_info_ptr.HasOpProtoAndChecker()) { + const proto::OpProto &proto = op_info_ptr.Proto(); + for (int i = 0; i != proto.attrs_size(); ++i) { + const proto::OpProto::Attr &attr = proto.attrs(i); + if (attr.name() == name) { + return true; + } + } + } + } + return false; +} + proto::AttrType OpDesc::GetAttrType(const std::string &name) const { auto it = attrs_.find(name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 30c8a26c3d2f0068674aa70b4ff875a2f73c1dca..e8debec7f13706b7fc5a4882d237ee2257e53b7e 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -65,6 +65,8 @@ class OpDesc { return attrs_.find(name) != attrs_.end(); } + bool HasProtoAttr(const std::string &name) const; + proto::AttrType GetAttrType(const std::string &name) const; std::vector AttrNames() const; diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 36673e48c2047bca54f604b082dfec123f1e2c82..6d39bb3c524b4725dfebd6ef07594b0b45c65463 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -319,7 +319,7 @@ struct OpKernelRegistrarFunctorExGet().value()); } if (t != nullptr) { + PADDLE_ENFORCE(t->IsInitialized(), "Input %s is not initialized: %s", + ipt_name, DebugString()); int tmp = static_cast(ToDataType(t->type())); PADDLE_ENFORCE( tmp == data_type || data_type == -1, diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 8d35361eb6506f4f4eca21f7ff2ea69f6dd6c1b9..81d1024cb65f6209c028e922fa15cdaee18eda94 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -27,17 +27,41 @@ limitations under the License. */ #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" +#include "paddle/fluid/framework/details/reference_count_pass_helper.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/platform/profiler.h" +#ifdef WITH_GPERFTOOLS +#include "gperftools/profiler.h" +#endif +DEFINE_string(pe_profile_fname, "", + "Profiler filename for PE, which generated by gperftools." + "Only valid when compiled `WITH_PRIFILER=ON`. Empty if disable."); + namespace paddle { namespace framework { +static std::once_flag gProfileOnce; +#ifdef WITH_GPERFTOOLS +static bool gProfileStarted = false; +#endif class ParallelExecutorPrivate { public: explicit ParallelExecutorPrivate(const std::vector &places) - : places_(places) {} + : places_(places) { + if (!FLAGS_pe_profile_fname.empty()) { + std::call_once(gProfileOnce, [] { +#ifdef WITH_GPERFTOOLS + ProfilerStart(FLAGS_pe_profile_fname.c_str()); + gProfileStarted = true; +#else + LOG(WARNING) << "Paddle is not compiled with gperftools. " + "FLAGS_pe_profile_fname will be ignored"; +#endif + }); + } + } ~ParallelExecutorPrivate() { if (own_local_scope_) { @@ -50,6 +74,26 @@ class ParallelExecutorPrivate { } } } + + std::unique_ptr PrepareGCAndRefCnts( + std::unique_ptr graph, size_t max_memory_size); + + inline bool HasGarbageCollectors() const { return !gcs_.empty(); } + + void ResetRuntimeReferenceCount(const std::vector &fetch_tensors, + const std::string &fetched_var_name) { + for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) { + for (auto &pair : global_ref_cnts_[i]) { + runtime_ref_cnts_[i][pair.first] = pair.second; + } + + for (auto &fetch_name : fetch_tensors) { + runtime_ref_cnts_[i].erase(fetch_name); + } + runtime_ref_cnts_[i].erase(fetched_var_name); + } + } + std::vector places_; std::vector local_scopes_; Scope *global_scope_; // not owned @@ -61,8 +105,76 @@ class ParallelExecutorPrivate { bool own_local_scope_; bool use_cuda_; bool use_all_reduce_; + + // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and + // then keeps unchanged + // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_ + std::vector global_ref_cnts_; + std::vector runtime_ref_cnts_; + details::GarbageCollectorMap gcs_; }; +std::unique_ptr ParallelExecutorPrivate::PrepareGCAndRefCnts( + std::unique_ptr graph, size_t max_memory_size) { + for (size_t i = 0; i < places_.size(); ++i) { + auto &place = places_[i]; + if (gcs_.count(place) > 0) { + continue; + } + std::unique_ptr gc; +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place)) { + if (IsFastEagerDeletionModeEnabled()) { + gc.reset(new UnsafeFastGPUGarbageCollector( + boost::get(place), max_memory_size)); + } else { + gc.reset(new StreamGarbageCollector( + boost::get(place), max_memory_size)); + } + VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; + } else { +#endif + if (platform::is_cpu_place(place)) { + gc.reset(new CPUGarbageCollector(boost::get(place), + max_memory_size)); + VLOG(10) << "Created GarbageCollector at " << place; + } else { + PADDLE_THROW("Unsupported place for garbage collection"); + } +#ifdef PADDLE_WITH_CUDA + } +#endif + + gcs_.emplace(place, std::move(gc)); + } + + if (!gcs_.empty()) { + std::vector last_live_ops_of_vars; + + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, + &global_ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars, + &last_live_ops_of_vars); + graph = ref_cnt_pass->Apply(std::move(graph)); + VLOG(10) << "ReferenceCountPass Applied"; + + auto eager_deletion_pass = + ir::PassRegistry::Instance().Get("eager_deletion_pass"); + eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount, + &runtime_ref_cnts_); + eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_); + eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars, + &last_live_ops_of_vars); + eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_); + graph = eager_deletion_pass->Apply(std::move(graph)); + VLOG(10) << "EagerDeletionPass Applied"; + } + + return graph; +} + std::vector &ParallelExecutor::GetLocalScopes() { return member_->local_scopes_; } @@ -161,34 +273,6 @@ ParallelExecutor::ParallelExecutor( member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get()); graphs.push_back(std::move(graph)); } - - auto max_memory_size = GetEagerDeletionThreshold(); - // FIXME(Yancey1989): need to fix on parallel graph mode - if (max_memory_size >= 0 && - exec_strategy.type_ != ExecutionStrategy::kParallelGraph) { - for (auto &place : member_->places_) { - if (!platform::is_gpu_place(place)) continue; - auto gpu_place = boost::get(place); - if (gcs_[gpu_place.device] == nullptr) { - ref_cnts_[gpu_place.device].reset(new details::ReferenceCountMap()); - cur_ref_cnts_[gpu_place.device].reset( - new details::AtomicReferenceCountMap()); - gcs_[gpu_place.device].reset( - new StreamGarbageCollector(gpu_place, max_memory_size)); - } - } - if (!gcs_.empty()) { - for (size_t i = 0; i < graphs.size(); ++i) { - auto ref_cnt_pass = - ir::PassRegistry::Instance().Get("reference_count_pass"); - ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_); - ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_); - ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_); - graphs[i] = ref_cnt_pass->Apply(std::move(graphs[i])); - graphs[i]->SetNotOwned("garbage_collector", &gcs_); - } - } - } #else std::unique_ptr graph = build_strategy.Apply(main_program, member_->places_, loss_var_name, @@ -317,19 +401,16 @@ void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { - platform::RecordBlock b(0); -#ifdef PADDLE_WITH_CUDA - if (!gcs_.empty()) { - ResetReferenceCount(); - for (auto &pair : cur_ref_cnts_) { - auto &name_map = *(pair.second); - for (auto &fetch_name : fetch_tensors) { - name_map.erase(fetch_name); - } - name_map.erase(fetched_var_name); - } +#ifdef WITH_GPERFTOOLS + if (gProfileStarted) { + ProfilerFlush(); } #endif + + platform::RecordBlock b(0); + if (member_->HasGarbageCollectors()) { + member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name); + } auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; @@ -373,13 +454,11 @@ ParallelExecutor::~ParallelExecutor() { for (auto &p : member_->places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); } - // member_ must be destructed before gcs_ since the destructor of - // ReferenceCountOpHandle use raw pointers of gcs_ inside. - member_.reset(); + delete member_; } } // namespace framework } // namespace paddle -#ifdef PADDLE_WITH_CUDA + USE_PASS(reference_count_pass); -#endif +USE_PASS(eager_deletion_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index ef09b98b2aa91a9d729b94d15dbb676dde4092b6..1fc17a0d64d50eb70ce66cacd4752a5b96d5e894 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -14,7 +14,6 @@ limitations under the License. */ #pragma once -#include #include #include #include @@ -29,10 +28,6 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/framework/details/reference_count_pass.h" -#endif - namespace paddle { namespace framework { @@ -75,24 +70,7 @@ class ParallelExecutor { private: void BCastParamsToDevices(const std::unordered_set &vars) const; - std::unique_ptr member_; - -#ifdef PADDLE_WITH_CUDA - // ref_cnts_ is only initialized when ParallelExecutor constructs, and then - // keeps unchanged - // Before each iteration, cur_ref_cnts_ is reset to ref_cnts_ - details::DeviceReferenceCountMap ref_cnts_; - details::AtomicDeviceReferenceCountMap cur_ref_cnts_; - details::DeviceGarbageCollectorMap gcs_; - - void ResetReferenceCount() { - for (auto &pair1 : ref_cnts_) { - for (auto &pair2 : *(pair1.second)) { - (*(cur_ref_cnts_[pair1.first]))[pair2.first] = pair2.second; - } - } - } -#endif + ParallelExecutorPrivate *member_; }; } // namespace framework diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 0d261dd7ccc323abddd2c3ef13f1874661a8ca75..6fa5e99f9f3a7e871f1a742a30803853988ea6eb 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -38,6 +38,10 @@ DEFINE_double( "Memory size threshold (GB) when the garbage collector clear tensors." "Disabled when this value is less than 0"); +DEFINE_bool(fast_eager_deletion_mode, false, + "Fast eager deletion mode. If enabled, memory would release " + "immediately without waiting GPU kernel ends."); + // When in inference scenario, the scopes will not be written by two threads in // a mean time, but a scope may be read by multiple threads concurrently, and // the mutex will cause serious performance issue. @@ -58,6 +62,8 @@ int64_t GetEagerDeletionThreshold() { (static_cast(1) << 30)); } +bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; } + Scope::~Scope() { DropKids(); } Scope& Scope::NewScope() const { diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 1901ffbe57e0d85193c3a218f06eba06a0f287a5..aded1f771cedbf2442ad36d7fab3e6e6caffdc24 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -27,6 +27,7 @@ namespace paddle { namespace framework { int64_t GetEagerDeletionThreshold(); +bool IsFastEagerDeletionModeEnabled(); class Scope; diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 71e8badd4b6b08e7d380fd45d93a33176172081d..153222506af02911b792907d392a42e5499c5e7a 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -158,6 +158,10 @@ class Tensor { const std::shared_ptr& Holder() const { return holder_; } size_t offset() const { return offset_; } + std::shared_ptr MoveMemoryHolder() { + return std::move(holder_); + } + private: /*! holds the memory block if allocated. */ std::shared_ptr holder_; diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..373d292b443b7651b785a52a6986b0a0be58ad61 --- /dev/null +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -0,0 +1,3 @@ +cc_library(layer SRCS layer.cc DEPS proto_desc operator) +cc_library(tracer SRCS tracer.cc DEPS proto_desc) +cc_library(engine SRCS engine.cc) diff --git a/paddle/fluid/imperative/engine.cc b/paddle/fluid/imperative/engine.cc new file mode 100644 index 0000000000000000000000000000000000000000..de7ab0e5918281579728ef48d1517be2cd530af7 --- /dev/null +++ b/paddle/fluid/imperative/engine.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/engine.h" + +#include // NOLINT +#include + +#include "glog/logging.h" + +namespace paddle { +namespace imperative { + +static std::once_flag init_engine; +static Engine* engine; + +class DummyEngine : public Engine { + public: + void Enqueue(Runnable* runnable) override { + queued_runnables_.push_back(runnable); + } + + size_t Size() const override { return queued_runnables_.size(); } + + void Sync() override { + for (Runnable* l : queued_runnables_) { + LOG(INFO) << "running " << reinterpret_cast(l); + } + queued_runnables_.clear(); + } + + private: + std::vector queued_runnables_; +}; + +Engine* GetEngine() { + std::call_once(init_engine, []() { engine = new DummyEngine(); }); + return engine; +} + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/engine.h b/paddle/fluid/imperative/engine.h new file mode 100644 index 0000000000000000000000000000000000000000..a1dfa5bda38d0c419aa4ccbea77b32eb7e0d5b23 --- /dev/null +++ b/paddle/fluid/imperative/engine.h @@ -0,0 +1,39 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace paddle { +namespace imperative { + +struct Runnable {}; + +class Engine { + public: + virtual ~Engine() {} + + virtual void Enqueue(Runnable* runnable) = 0; + + virtual size_t Size() const = 0; + + virtual void Sync() = 0; +}; + +Engine* GetEngine(); + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc new file mode 100644 index 0000000000000000000000000000000000000000..612503768079472ba233ee3fcd43a47fdba9a0cc --- /dev/null +++ b/paddle/fluid/imperative/layer.cc @@ -0,0 +1,221 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/layer.h" +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/string/printf.h" + +namespace paddle { +namespace imperative { + +using framework::Variable; + +void AddTo(Variable* src, Variable* dst) { + framework::LoDTensor* dst_tensor = dst->GetMutable(); + framework::LoDTensor* src_tensor = src->GetMutable(); + PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), "%lld vs %lld", + dst_tensor->numel(), src_tensor->numel()); + float* dst_data = dst_tensor->mutable_data(platform::CPUPlace()); + const float* src_data = src_tensor->data(); + for (size_t i = 0; i < src_tensor->numel(); ++i) { + dst_data[i] += src_data[i]; + } +} + +class Autograd { + public: + explicit Autograd(framework::Scope* scope) : scope_(scope) {} + + void RunBackward(VarBase* var) { + PADDLE_ENFORCE(var->pre_op_->op_desc_); + // TODO(panyx0718): Only create for vars that "require_grad" + (*var->pre_op_->output_vars_)[var->pre_op_out_idx_]->grads_ = var->grads_; + + std::deque ready; + ready.push_back(var->pre_op_); + + std::map dep_counts = ComputeDepCounts(var->pre_op_); + + while (!ready.empty()) { + OpBase* ready_op = ready.front(); + ready.pop_front(); + std::vector input_grads = ready_op->ApplyGrad(scope_); + + for (size_t i = 0; i < input_grads.size(); ++i) { + if (!input_grads[i]) continue; + OpBase* pre_op = ready_op->pre_ops_->at(i); + if (!pre_op) continue; + + dep_counts[pre_op] -= 1; + PADDLE_ENFORCE(dep_counts[pre_op] >= 0); + bool pre_op_ready = dep_counts[pre_op] == 0; + if (pre_op_ready) { + ready.push_back(pre_op); + } + } + } + } + + private: + std::map ComputeDepCounts(OpBase* op) { + std::map ret; + + std::deque queue; + queue.push_back(op); + std::unordered_set visited; + visited.insert(op); + while (!queue.empty()) { + OpBase* candidate = queue.front(); + queue.pop_front(); + for (OpBase* pre_op : *(candidate->pre_ops_)) { + if (!pre_op) continue; + if (visited.find(pre_op) == visited.end()) { + visited.insert(pre_op); + queue.push_back(pre_op); + } + ret[pre_op] += 1; + } + } + + return ret; + } + + framework::Scope* scope_; +}; + +framework::Variable* CreateVariable(const std::string& name, + const framework::DDim& dim, float val, + framework::Scope* scope, + bool random_name = true) { + std::string varname = name; + if (random_name) { + std::mt19937 rng; + rng.seed(std::random_device()()); + std::uniform_int_distribution dist6( + 1, std::numeric_limits::max()); + int id = dist6(rng); + varname = string::Sprintf("%s@%d", varname, id); + } + + VLOG(3) << "creating var " << varname; + framework::Variable* var = scope->Var(varname); + framework::LoDTensor* tensor = var->GetMutable(); + + float* data = tensor->mutable_data(dim, platform::CPUPlace()); + std::fill(data, data + tensor->numel(), val); + return var; +} + +framework::LoDTensor& VarBase::Grad() { + VLOG(3) << "get var grad " << var_desc_->Name(); + return *grads_->GetMutable(); +} + +void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) { + VLOG(3) << "apply var grad " << var_desc_->Name() << " " + << grad->Get().data()[0]; + if (!grads_) { + grads_ = + CreateVariable(string::Sprintf("%s@IGrad", var_desc_->Name()), + var_->Get().dims(), 0.0, scope); + } + AddTo(grad, grads_); + VLOG(3) << "grad_ after apply var grad " << var_desc_->Name() << " " + << grads_->Get().data()[0]; +} + +std::vector OpBase::ApplyGrad(framework::Scope* scope) { + VLOG(3) << "op grad " << grad_op_desc_->Type(); + + for (const std::string& grad_invar : grad_op_desc_->InputArgumentNames()) { + if (grad_to_var_->find(grad_invar) == grad_to_var_->end()) { + // grad op inputs can be forward inputs, so not in grad_to_var. + continue; + } + VLOG(3) << "op grad in var " << grad_invar; + block_->FindRecursiveOrCreateVar(grad_invar); + framework::Variable* var = scope->Var(grad_invar); + const std::string& invar = grad_to_var_->at(grad_invar); + for (VarBase* varbase : *output_vars_) { + // Use the accumulated grads_ by sharing the input with grads_. + if (varbase->var_desc_->Name() == invar) { + var->GetMutable()->ShareDataWith( + varbase->grads_->Get()); + break; + } + } + } + + for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) { + VLOG(3) << "grad outvar " << outvar; + block_->FindRecursiveOrCreateVar(outvar); + framework::Variable* var = scope->Var(outvar); + if (!var->IsInitialized()) { + framework::VarDesc* var_desc = block_->FindVar(outvar); + if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) { + var->GetMutable(); + } else { + LOG(ERROR) << "tracer doesn't support yet"; + } + } + } + grad_op_desc_->InferShape(*block_); + grad_op_desc_->InferVarType(block_); + std::unique_ptr opbase = + framework::OpRegistry::CreateOp(*grad_op_desc_); + + opbase->Run(*scope, platform::CPUPlace()); + + // `ret` matches exactly with `input_vars_` of forward op. + std::vector ret; + for (size_t i = 0; i < input_vars_->size(); ++i) { + bool found = false; + for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) { + Variable* var = scope->FindVar(outvar); + VarBase* origin_var = (*input_vars_)[i]; + std::string orig_var = grad_to_var_->at(outvar); + PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var); + VLOG(3) << "apply grad " << outvar << " with origin " << orig_var; + origin_var->ApplyGrad(scope, var); + found = true; + ret.push_back(var); + // TODO(panyx0718): There might be another outvar with the same name. + // In that case, it doesn't matter the first one or the second one is + // used. + break; + } + if (!found) { + ret.push_back(nullptr); + } + } + return ret; +} + +void VarBase::RunBackward(framework::Scope* scope) { + grads_ = CreateVariable(framework::GradVarName(var_desc_->Name()), + var_->Get().dims(), 1.0, scope, + false); + if (!pre_op_) return; + Autograd(scope).RunBackward(this); +} + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h new file mode 100644 index 0000000000000000000000000000000000000000..85a71ca83d21ed2595ddbe684300a46c05fed3af --- /dev/null +++ b/paddle/fluid/imperative/layer.h @@ -0,0 +1,102 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace imperative { + +class OpBase; + +class VarBase { + public: + VarBase() + : pre_op_(nullptr), + pre_op_out_idx_(-1), + var_desc_(nullptr), + var_(nullptr), + grads_(nullptr) {} + + virtual ~VarBase() {} + + void ApplyGrad(framework::Scope* scope, framework::Variable* grad); + + void RunBackward(framework::Scope* scope); + + framework::LoDTensor& Grad(); + + OpBase* pre_op_; + int pre_op_out_idx_; + + framework::VarDesc* var_desc_; + framework::Variable* var_; + framework::Variable* grads_; +}; + +class OpBase { + public: + OpBase() + : input_vars_(new std::vector()), + output_vars_(new std::vector()), + pre_ops_(new std::vector()), + pre_ops_out_idx_(new std::vector()), + op_desc_(nullptr), + grad_op_desc_(nullptr) {} + + virtual ~OpBase() { + delete input_vars_; + delete output_vars_; + + delete pre_ops_; + delete pre_ops_out_idx_; + + if (grad_op_desc_) delete grad_op_desc_; + if (grad_to_var_) delete grad_to_var_; + } + + std::vector ApplyGrad(framework::Scope* scope); + + std::vector* input_vars_; + std::vector* output_vars_; + std::vector* pre_ops_; + std::vector* pre_ops_out_idx_; + framework::OpDesc* op_desc_; + + framework::OpDesc* grad_op_desc_; + std::unordered_map* grad_to_var_; + framework::BlockDesc* block_; +}; + +class Layer { + public: + virtual ~Layer() {} + + virtual std::vector Forward(const std::vector& inputs) { + std::vector vars; + return vars; + } + + virtual void Backward() { LOG(ERROR) << "To support customize"; } +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc new file mode 100644 index 0000000000000000000000000000000000000000..f64f9e72c4a23528948183b909d65e90783a4463 --- /dev/null +++ b/paddle/fluid/imperative/tracer.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/imperative/tracer.h" + +namespace paddle { +namespace imperative {} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h new file mode 100644 index 0000000000000000000000000000000000000000..433d07c0e5aa0986ab1e9fe349ef865d2851c0c0 --- /dev/null +++ b/paddle/fluid/imperative/tracer.h @@ -0,0 +1,128 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/imperative/engine.h" +#include "paddle/fluid/imperative/layer.h" + +namespace paddle { +namespace imperative { + +void CreateGradOp(const framework::OpDesc& op_desc, + const std::unordered_set& no_grad_set, + const std::vector& grad_sub_block, + framework::OpDesc** grad_op_desc, + std::unordered_map* grad_to_var) { + std::vector> grad_op_descs = + framework::OpInfoMap::Instance() + .Get(op_desc.Type()) + .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block); + PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now."); + // TODO(panyx0718): Leak? + *grad_op_desc = grad_op_descs[0].release(); +} + +class Tracer { + public: + explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) { + root_scope_ = new framework::Scope(); + scopes_[root_block_] = root_scope_; + } + + virtual ~Tracer() { delete root_scope_; } + + void Trace(OpBase* op, const std::vector& inputs, + const std::vector& outputs, + framework::BlockDesc* block) { + framework::Scope* scope = GetScope(block); + framework::OpDesc* op_desc = op->op_desc_; + VLOG(3) << "tracer tracing " << op_desc->Type(); + op_desc->InferShape(*block); + op_desc->InferVarType(block); + std::unique_ptr op_base = + framework::OpRegistry::CreateOp(*op_desc); + + *op->input_vars_ = inputs; + for (VarBase* input : inputs) { + const std::string vname = input->var_desc_->Name(); + framework::Variable* var = scope->Var(vname); + input->var_ = var; + if (!var->IsInitialized()) { + framework::VarDesc* var_desc = block->FindVar(vname); + if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) { + var->GetMutable(); + } else { + LOG(ERROR) << "tracer doesn't support yet"; + } + } + if (input->pre_op_) { + op->pre_ops_->push_back(input->pre_op_); + op->pre_ops_out_idx_->push_back(input->pre_op_out_idx_); + } else { + op->pre_ops_->push_back(nullptr); + } + } + + *op->output_vars_ = outputs; + for (size_t i = 0; i < outputs.size(); ++i) { + const std::string vname = outputs[i]->var_desc_->Name(); + framework::Variable* var = scope->Var(vname); + if (!var->IsInitialized()) { + framework::VarDesc* var_desc = block->FindVar(vname); + if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) { + var->GetMutable(); + } else { + LOG(ERROR) << "tracer doesn't support yet"; + } + } + outputs[i]->var_ = var; + outputs[i]->pre_op_ = op; + outputs[i]->pre_op_out_idx_ = i; + } + op_base->Run(*scope, platform::CPUPlace()); + framework::OpDesc* grad_op_desc; + auto grad_to_var = new std::unordered_map(); + CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var); + op->grad_op_desc_ = grad_op_desc; + op->grad_to_var_ = grad_to_var; + op->block_ = block; + } + + framework::Scope* GetScope(framework::BlockDesc* block) { + if (scopes_.find(block) != scopes_.end()) { + return scopes_.at(block); + } + framework::BlockDesc* parent_block = block->ParentBlock(); + PADDLE_ENFORCE(scopes_.find(parent_block) != scopes_.end()); + framework::Scope* scope = &scopes_[parent_block]->NewScope(); + scopes_[block] = scope; + return scope; + } + + private: + std::map scopes_; + framework::BlockDesc* root_block_; + framework::Scope* root_scope_; +}; + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index b8a045c18fab54581b4d2b902be373f55ad09e8a..c6e923c00484f01f17550ae2926dabcadc0c3ac6 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -44,9 +44,10 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { argument->SetMainProgram(program.release()); } else if (argument->model_program_path_valid() && argument->model_params_path_valid()) { - auto program = - LoadModel(argument->model_program_path(), argument->model_params_path(), - argument->scope_ptr(), place, argument->model_from_memory()); + auto program = LoadModel( + argument->model_program_path(), argument->model_params_path(), + argument->scope_ptr(), place, + argument->model_from_memory_valid() && argument->model_from_memory()); argument->SetMainProgram(program.release()); } else { PADDLE_THROW( diff --git a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc index 343fd3f7c5aed6931fc215445c17d3ed7074368e..1d0d83d1f368f879878a4df8b2eefae0bc89423d 100644 --- a/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/pool2d_op.cc @@ -109,8 +109,12 @@ class Pool2dOpConverter : public OpConverter { } if (pool_type == "max") { - nvinfer1::DimsHW pre_pad(paddings[0], paddings[1]); - nvinfer1::DimsHW post_pad(paddings[0], paddings[1]); + // Under ceil mode, the pre_pad and post_pad are used to + // record the the padding size. In some ceil mode cases, + // we do not need padding, so we initialize the two vars to 0. + + nvinfer1::DimsHW pre_pad(0, 0); + nvinfer1::DimsHW post_pad(0, 0); if (ceil_mode) { // If ceil mode is true, we will pad the appropriate size to the input. DealCeilMode(input_shape, ksize, strides, paddings, &pre_pad, &post_pad, diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index a07626a10315a6206f8c1ebc9a19df90663a88ee..8a4bc04b67879918c6ac8d1b40dae68a107034d4 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -1,4 +1,4 @@ -set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor) +set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor benchmark) if(WITH_GPU AND TENSORRT_FOUND) set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor) diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index d572ea0177c1e398229a02718ca31cc78a7059ef..8209a049f4614fe31c22c4e83c1968411b749b49 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -30,8 +30,10 @@ #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/tests/api/config_printer.h" #include "paddle/fluid/inference/tests/test_helper.h" +#include "paddle/fluid/inference/utils/benchmark.h" #include "paddle/fluid/platform/profiler.h" +DEFINE_string(model_name, "", "model name"); DEFINE_string(infer_model, "", "model path"); DEFINE_string(infer_data, "", "data file"); DEFINE_int32(batch_size, 1, "batch size."); @@ -40,6 +42,8 @@ DEFINE_bool(test_all_data, false, "Test the all dataset in data file."); DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads."); DEFINE_bool(use_analysis, true, "Running the inference program in analysis mode."); +DEFINE_bool(record_benchmark, false, + "Record benchmark after profiling the model"); DECLARE_bool(profile); DECLARE_int32(paddle_num_threads); @@ -192,8 +196,16 @@ void TestOneThreadPrediction( predictor->Run(inputs[j], outputs, batch_size); } } - PrintTime(batch_size, num_times, 1, 0, run_timer.toc() / num_times, - inputs.size()); + + double latency = run_timer.toc() / num_times; + PrintTime(batch_size, num_times, 1, 0, latency, inputs.size()); + if (FLAGS_record_benchmark) { + Benchmark benchmark; + benchmark.SetName(FLAGS_model_name); + benchmark.SetBatchSize(batch_size); + benchmark.SetLatency(latency); + benchmark.PersistToFile("benchmark_record.txt"); + } } } diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc index ef612ce6148329c33f194842945bb5438afcf645..9eb3fb5da1065f14d9ad1c3520f6415fbadfdca1 100644 --- a/paddle/fluid/inference/tests/api/trt_models_tester.cc +++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc @@ -135,6 +135,9 @@ TEST(TensorRT_resnext50, compare) { TEST(TensorRT_resnext50, profile) { std::string model_dir = FLAGS_infer_model + "/resnext50"; + // Set FLAGS_record_benchmark to true to record benchmark to file. + // FLAGS_record_benchmark=true; + FLAGS_model_name = "resnext50"; profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt); } diff --git a/paddle/fluid/inference/utils/benchmark.cc b/paddle/fluid/inference/utils/benchmark.cc index d03aa11b75ee58524746212e43a5796773f47932..0bd526bcac2d9ceda95730dc3c5210aed8ccfb5c 100644 --- a/paddle/fluid/inference/utils/benchmark.cc +++ b/paddle/fluid/inference/utils/benchmark.cc @@ -30,7 +30,7 @@ std::string Benchmark::SerializeToString() const { ss << '\n'; ss << name_ << "\t"; - ss << batch_size_ << "\t"; + ss << batch_size_ << "\t\t"; ss << num_threads_ << "\t"; ss << latency_ << "\t"; ss << 1000.0 / latency_; diff --git a/paddle/fluid/inference/utils/visualizer.cc b/paddle/fluid/inference/utils/visualizer.cc index 040b6476fb4febc5ca1912c8db72dc63c3bced08..7c0dd64dea88e51b24c4bc04818d633ee0d2f722 100644 --- a/paddle/fluid/inference/utils/visualizer.cc +++ b/paddle/fluid/inference/utils/visualizer.cc @@ -26,9 +26,6 @@ DEFINE_string(model_dir, "", "model directory"); DEFINE_string(model_program_path, "", "model program path"); DEFINE_string(model_params_path, "", "model params path"); -USE_PASS(graph_viz_pass); -USE_PASS(graph_to_program_pass); - using paddle::inference::analysis::Argument; namespace paddle { @@ -40,7 +37,6 @@ void Visualizer::SetArgument(Argument *argument) { argument_ = argument; } bool Visualizer::Run() { paddle::framework::InitDevices(false); paddle::inference::analysis::Analyzer().Run(argument_); - return true; } @@ -77,7 +73,7 @@ int main(int argc, char *argv[]) { // Only 1 pass, default filename is 0_ir_origin.dot // For more details, looking for paddle::inference::analysis::IRPassManager - argument.SetIrAnalysisPasses({"graph_viz_pass"}); + argument.SetIrAnalysisPasses({"infer_clean_graph_pass", "graph_viz_pass"}); std::unique_ptr scope{ new paddle::framework::Scope()}; @@ -90,3 +86,7 @@ int main(int argc, char *argv[]) { return 0; } + +USE_PASS(infer_clean_graph_pass); +USE_PASS(graph_viz_pass); +USE_PASS(graph_to_program_pass); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 87d549678a0e6c183aac89539cf1f6331729de2c..c7df3ea58a91579e35ff0d486516271a6daf054f 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -301,23 +301,22 @@ template struct GeluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - auto temp = - ((x * static_cast(M_SQRT1_2)).erf()).template cast().eval(); + auto temp = (x * static_cast(M_SQRT1_2)).erf(); out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); } }; template struct GeluGradFunctor : BaseActivationFunctor { - bool Inplace() const { return IsInplace("gelu"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto temp = (static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * - ((-static_cast(0.5) * x.square()).exp())) - .template cast() - .eval(); - dx.device(d) = dout * (out / x + temp); + auto first = static_cast(0.5) * + (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); + + auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * + (-static_cast(0.5) * x.square()).exp(); + dx.device(d) = dout * (first + second); } }; diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.cu b/paddle/fluid/operators/bilinear_tensor_product_op.cu index 9426ffbe174c7daf9f24525f5f7ca12d986042f4..c2b4f69e6854522b91dfd9fb5f738c0e5ffc77b1 100644 --- a/paddle/fluid/operators/bilinear_tensor_product_op.cu +++ b/paddle/fluid/operators/bilinear_tensor_product_op.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/fluid/operators/bilinear_tensor_product_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/bpr_loss_op.cc b/paddle/fluid/operators/bpr_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9258d7c7e83122149c7cbc42e4a4bdd84903ce67 --- /dev/null +++ b/paddle/fluid/operators/bpr_loss_op.cc @@ -0,0 +1,145 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/bpr_loss_op.h" + +namespace paddle { +namespace operators { + +class BprLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(rank, label_dims.size(), + "Input(X) and Input(Label) shall have the same rank."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "Input(X) and Input(Label) shall have the same shape " + "except the last dimension."); + + auto y_dims = x_dims; + y_dims[rank - 1] = 1; + ctx->SetOutputDim("Y", y_dims); + ctx->ShareLoD("X", /*->*/ "Y"); + } + + protected: + // Explicitly set that the data type of computation kernel of Seq-bpr + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +class BprLossGradientOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) shoudl be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto label_dims = ctx->GetInputDim("Label"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + int rank = x_dims.size(); + PADDLE_ENFORCE_EQ(dy_dims.size(), rank, + "Input(Y@Grad) and Input(X) should have the same rank."); + PADDLE_ENFORCE_EQ(label_dims.size(), rank, + "Input(Label) and Input(X) should have the same rank."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(label_dims, 0, rank - 1), + "The Input(X) and Input(Label) should have the same " + "shape except the last dimension."); + PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1), + framework::slice_ddim(dy_dims, 0, rank - 1), + "The Input(X) and Input(Y@Grad) should have the same " + "shape except the last dimension."); + PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1, + "The last dimension of Input(Y@Grad) should be 1."); + PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1, + " the last dimension of Input(Label) should be 1."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of cross_entropy + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + platform::CPUPlace()); + } +}; + +class BprLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), a tensor whose last dimension " + "size is equal to the number of classes. This input is a " + "real number."); + AddInput( + "Label", + "(Tensor), the tensor which represents the ground truth. It has the " + "same shape with 'X' except the last dimension. the last dimension " + "size is 1."); + AddOutput("Y", + "(Tensor, default Tensor), a tensor whose shape is same " + "with 'X' except that the last dimension size is 1. It " + "represents the sequence bpr loss."); + AddComment(R"DOC( +Bayesian Personalized Ranking Loss Operator. + +This operator belongs to pairwise ranking loss. Label is the desired item. +The loss at a given point in one session is defined as: +$Y[i] = -\frac{1}{N_{i}} * \sum_{j=0}^{N_{i}}\log(\sigma(X[i, Label[i]]-X[i, j]))$ + +Learn more details by reading paper (https://arxiv.org/abs/1511.06939) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPUCtx = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp); +REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel, + ops::BprLossOpKernel); +REGISTER_OP_CPU_KERNEL(bpr_loss_grad, + ops::BprLossGradientOpKernel, + ops::BprLossGradientOpKernel); diff --git a/paddle/fluid/operators/bpr_loss_op.h b/paddle/fluid/operators/bpr_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e223be7af82146e7c69c7c5aab8f08d0fe0d1710 --- /dev/null +++ b/paddle/fluid/operators/bpr_loss_op.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +/*Todo: + *Find a way to adapt TolerableValue, using blas or eigen. + */ +template +struct TolerableValue { + HOSTDEVICE T operator()(const T& x) const { + PADDLE_ASSERT(std::is_floating_point::value); + const T kApproInf = 1e20; + if (x == INFINITY) return kApproInf; + if (x == -INFINITY) return -kApproInf; + return x; + } +}; + +template +class BprLossOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* label = ctx.Input("Label"); + auto* y = ctx.Output("Y"); + y->mutable_data(ctx.GetPlace()); + int rank = x->dims().size(); + + Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1); + Tensor labels_2d = framework::ReshapeToMatrix(*label, rank - 1); + Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1); + + const framework::Tensor* logits = &x_2d; + const framework::Tensor* labels = &labels_2d; + framework::Tensor* out = &y_2d; + + const int step_size = logits->dims()[0]; + const int class_num = logits->dims()[1]; + const T* logits_data = logits->data(); + T* loss_data = out->data(); + + const int64_t* label_data = labels->data(); + for (int i = 0; i < step_size; ++i) { + int lbl_pos = label_data[i]; + PADDLE_ENFORCE_GE(lbl_pos, 0); + PADDLE_ENFORCE_LT(lbl_pos, class_num); + int index_pos = i * class_num + lbl_pos; + T sum = static_cast(0); + for (int j = 0; j < class_num; j++) { + if (j == lbl_pos) continue; + int index_neg = i * class_num + j; + sum += TolerableValue()(-std::log( + 1.0f + TolerableValue()(std::exp(logits_data[index_neg] - + logits_data[index_pos])))); + } + loss_data[i] = -sum / (class_num - 1); + } + } +}; + +template +class BprLossGradientOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* dy = ctx.Input(framework::GradVarName("Y")); + auto* label = ctx.Input("Label"); + auto* dx = ctx.Output(framework::GradVarName("X")); + + const int step_size = x->dims()[0]; + const int num_classes = x->dims()[1]; + T* dx_data = dx->mutable_data(ctx.GetPlace()); + const T* dy_data = dy->data(); + const T* x_data = x->data(); + const int64_t* label_data = label->data(); + + for (size_t sample_id = 0; sample_id < step_size; sample_id++) { + for (size_t x_offset = sample_id * num_classes; + x_offset < (sample_id + 1) * num_classes; x_offset++) { + dx_data[x_offset] = static_cast(0); + } + auto p_index = sample_id * num_classes + label_data[sample_id]; + for (size_t ni = 0; ni < num_classes; ni++) { + if (label_data[sample_id] == ni) continue; + auto n_index = sample_id * num_classes + ni; + auto grad_ = -dy_data[sample_id] / + ((num_classes - 1) * + (1.0f + TolerableValue()(std::exp(x_data[p_index] - + x_data[n_index])))); + dx_data[p_index] += grad_; + dx_data[n_index] -= grad_; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/concat_mkldnn_op.cc b/paddle/fluid/operators/concat_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ad674056f0d753d79408a11eff1aca47a84998a --- /dev/null +++ b/paddle/fluid/operators/concat_mkldnn_op.cc @@ -0,0 +1,152 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/operators/concat_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using framework::DataLayout; +using framework::Tensor; +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::concat; +using mkldnn::stream; +using platform::to_void_cast; + +static void EnforceLayouts(const std::vector inputs) { + for (auto* input : inputs) { + const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN; + const bool is_format_defined = + input->format() != memory::format::format_undef; + PADDLE_ENFORCE(is_layout_correct && is_format_defined, + "Wrong layout/format set for Input tensor"); + } +} + +static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, + const mkldnn::engine& engine) { + constexpr auto data_type = mkldnn::memory::f32; + const auto dims = paddle::framework::vectorize2int(input.dims()); + const auto format = input.format(); + auto description = memory::desc(dims, data_type, format); + auto mem_prim_desc = memory::primitive_desc(description, engine); + return mem_prim_desc; +} + +static mkldnn::memory::format GetDstMemFormat( + const concat::primitive_desc& concat_pd) { + return (memory::format)concat_pd.dst_primitive_desc().desc().data.format; +} + +static platform::CPUPlace GetCpuPlace( + const paddle::framework::ExecutionContext& ctx) { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(paddle::platform::is_cpu_place(place), + "It must use CPUPlace."); + return boost::get(place); +} + +static const mkldnn::engine& GetMKLDNNEngine( + const paddle::framework::ExecutionContext& ctx) { + auto& dev_ctx = ctx.template device_context(); + return dev_ctx.GetEngine(); +} + +template +class ConcatPrimitiveFactory { + public: + concat::primitive_desc CreateConcatPrimDescriptor( + const std::vector multi_input, Tensor* output, + int concat_axis, const mkldnn::engine& mkldnn_engine) { + CreateSourcesDescriptors(multi_input, mkldnn_engine); + auto dst_desc = CreateDstMemDescriptor(output); + return concat::primitive_desc(dst_desc, concat_axis, srcs_pd); + } + + concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, + Tensor* output, platform::CPUPlace place) { + CreateSourcePrimitiveAts(); + dst_mem = CreateDstMemory(concat_pd, output, place); + return concat(concat_pd, inputs, dst_mem.get()); + } + + private: + memory::desc CreateDstMemDescriptor(Tensor* output) { + auto dst_dims = paddle::framework::vectorize2int(output->dims()); + return memory::desc(dst_dims, platform::MKLDNNGetDataType(), + memory::format::any); + } + + mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd, + Tensor* output, platform::CPUPlace place) { + return memory(concat_pd.dst_primitive_desc(), + output->mutable_data(place)); + } + + void CreateSourcesDescriptors(const std::vector multi_input, + const mkldnn::engine& mkldnn_engine) { + for (size_t i = 0; i < multi_input.size(); i++) { + auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); + srcs_pd.push_back(mem_prim_desc); + srcs.push_back( + memory(mem_prim_desc, to_void_cast(multi_input[i]->data()))); + } + } + + void CreateSourcePrimitiveAts() { + inputs.reserve(srcs.size()); + for (size_t i = 0; i < srcs.size(); i++) { + inputs.push_back(srcs[i]); + } + } + + private: + std::vector srcs_pd; + std::vector srcs; + std::vector inputs; + boost::optional dst_mem; // TODO(mgallus): change to std::optional +}; // upon introduction of C++17 to paddle + +template +class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + auto place = GetCpuPlace(ctx); + const auto& mkldnn_engine = GetMKLDNNEngine(ctx); + + auto multi_input = ctx.MultiInput("X"); + EnforceLayouts(multi_input); + Tensor* output = ctx.Output("Out"); + int64_t concat_axis = static_cast(ctx.Attr("axis")); + + ConcatPrimitiveFactory prim_creator; + auto concat_pd = prim_creator.CreateConcatPrimDescriptor( + multi_input, output, static_cast(concat_axis), mkldnn_engine); + auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place); + stream(stream::kind::eager).submit({concat}).wait(); + + output->set_layout(DataLayout::kMKLDNN); + output->set_format(GetDstMemFormat(concat_pd)); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace, + ops::ConcatMKLDNNOpKernel) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 57817da71adfd80faad29a48b05ba2f326de6c07..194f9cf5033a3a73afeb8e92ddbdcc7b316fcd35 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/concat_op.h" - #include #include +#ifdef PADDLE_WITH_MKLDNN +#include +#endif + namespace paddle { namespace operators { using framework::Tensor; @@ -59,6 +62,22 @@ class ConcatOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", out_dims); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]); + +#ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { @@ -66,6 +85,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Input tensors of concat operator.").AsDuplicable(); AddOutput("Out", "Output tensor of concat operator."); + AddAttr( + "use_mkldnn", + "(bool, default false) Indicates if MKL-DNN kernel will be used") + .SetDefault(false); AddAttr("axis", "The axis along which the input tensors will be concatenated.") .SetDefault(0); diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 6c1b2f329a59e1b27caad2996308b33b3a72de1d..5ab0918c486cc56c7d55f24f4952a013044971ee 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -32,6 +32,20 @@ static constexpr char kStepScopes[] = "StepScopes"; static constexpr char kX[] = "X"; static constexpr char kXGRAD[] = "X@GRAD"; static constexpr char kOutputs[] = "Out"; +static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; + +namespace { // NOLINT +static std::string GetSkipEagerDeletionVarsDebugString( + const std::vector &vars) { + std::string str = "Skip " + std::to_string(vars.size()) + + " var(s) in eager deletion mode: "; + for (auto &var : vars) { + str.append(var); + str.push_back(' '); + } + return str; +} +} // NOLINT class WhileOp : public framework::OperatorBase { public: @@ -59,7 +73,10 @@ class WhileOp : public framework::OperatorBase { "Condition of while op must in CPU memory."); bool is_test = Attr("is_test"); - auto ctx = executor.Prepare(*program, block->ID()); + auto &skip_vars = Attr>(kSkipEagerDeletionVars); + VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); + + auto ctx = executor.Prepare(*program, block->ID(), skip_vars); while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); @@ -96,6 +113,10 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") .SetDefault(false); + AddAttr>(kSkipEagerDeletionVars, + "Vars that would skip eager deletion." + "Users should not set this manually.") + .SetDefault(std::vector()); AddComment(R"DOC( )DOC"); } @@ -119,7 +140,10 @@ class WhileGradOp : public framework::OperatorBase { framework::Executor executor(dev_place); auto *block = Attr(kStepBlock); auto *program = block->Program(); - auto ctx = executor.Prepare(*program, block->ID()); + + auto &skip_vars = Attr>(kSkipEagerDeletionVars); + VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); + auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto *step_scopes = scope.FindVar(Input(kStepScopes))->GetMutable(); @@ -341,6 +365,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { // while operator could be renamed. while_grad->SetAttr("original_output_grad", output_grads_list); + while_grad->SetAttr(kSkipEagerDeletionVars, std::vector()); + return std::unique_ptr(while_grad); } }; diff --git a/paddle/fluid/operators/cos_sim_op.cu b/paddle/fluid/operators/cos_sim_op.cu index 82205e9c75402e368a2d1e161d471e35ff7356ea..3d144ca29d9989ad2cbb438a950860eaac873d07 100644 --- a/paddle/fluid/operators/cos_sim_op.cu +++ b/paddle/fluid/operators/cos_sim_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/cos_sim_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/crop_op.cu b/paddle/fluid/operators/crop_op.cu index b75678217e36aa2297c68a7f8e2a9dfafadaca72..66cb5c452de4b2107693127ce414daf9fb7cd7d8 100644 --- a/paddle/fluid/operators/crop_op.cu +++ b/paddle/fluid/operators/crop_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/crop_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/distributed/brpc_client.cc b/paddle/fluid/operators/distributed/brpc_client.cc index b394c678fb6503eb73a1e11e6feb814251e9e940..350969f74be258ffbfef687b56083a9c6508bc81 100644 --- a/paddle/fluid/operators/distributed/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc_client.cc @@ -158,7 +158,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { for (int i = 0; i < FLAGS_brpc_channel_num; ++i) { std::shared_ptr c(new ChannelContext()); if (c->channel.Init(ep.c_str(), &options) != 0) { - LOG(ERROR) << "Fail to initialize channel"; + LOG(FATAL) << "Fail to initialize channel"; return nullptr; } diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 857214aa211aee0251571e46049c66c084b470f1..f14dfcdb238a9580affde96e4d5a0093743eb6c8 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -390,8 +390,7 @@ void GRPCClient::Proceed() { VLOG(3) << c->GetVarHandlePtr()->String() << " process"; c->Process(); } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { - // FIXME(gongwb): parse error_details? - LOG(ERROR) << c->GetVarHandlePtr()->String() + LOG(FATAL) << c->GetVarHandlePtr()->String() << " meets grpc error, error_code:" << c->status_.error_code() << " error_message:" << c->status_.error_message() << " error_details:" << c->status_.error_details(); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index e011f47e086183a4ef3a3373c17acd6c21b6cf7e..d65491267de1ce3495d8b8250cf0cff570dfcc6a 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include #include #include diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index 2fb7eeb4b9e3119a6eea3e69a2a6002a80f6c0f3..fed12785f47e1b8eea3f053712830901bee3bdc9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index c5a1a7e08d89f3ef205af4c37246f8fa288189f3..1a149298fd33f132a90ff5de3b35dd5894a4ae68 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index a90dcd3ecf0da114110db5946e111a8b3a925e42..5d086a1b29febd8e57507eced7683f414ca34e07 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_max_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index ab77709c28c15a925bd3deac07c43e12b12cb781..cf93e5a97a3f3110aae907c593f58dbab0f9d090 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 4d16bc38e1d8e4cbbe3afbe08f233e14329e0f2e..833c4072826c58277bc23e03b787fafbbaa73d03 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu index 6ee0779f23bc2c734aa1d439abb12f366227e686..9263dbfebfd00451f3e67c3ca9d2081b5b4904bd 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 8d9bf7c4d81d49d83b5d1cf0369be5c9957242b4..6f17d3292f307b009c640738109d5a4f4ca4caa9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/expand_op.cu b/paddle/fluid/operators/expand_op.cu index 60363bfc86d7d1a79d7b018cee43a41c1247a994..d95c9b61802b5fe7059e1c95a50776db5aa7ad93 100644 --- a/paddle/fluid/operators/expand_op.cu +++ b/paddle/fluid/operators/expand_op.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/expand_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/gru_unit_op.cu b/paddle/fluid/operators/gru_unit_op.cu index fc92b3d4a7a5a933f31b21d18238de386b3afb4d..37689901ecbeeda44f52a2fc7a686f4edf6682bb 100644 --- a/paddle/fluid/operators/gru_unit_op.cu +++ b/paddle/fluid/operators/gru_unit_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/gru_unit_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/hinge_loss_op.cu b/paddle/fluid/operators/hinge_loss_op.cu index 9c0a85bee6e28865225c1848ea5a378f48932ceb..b5ea0a702e0e540c1831ca241a5def19f86c239c 100644 --- a/paddle/fluid/operators/hinge_loss_op.cu +++ b/paddle/fluid/operators/hinge_loss_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/hinge_loss_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/huber_loss_op.cu b/paddle/fluid/operators/huber_loss_op.cu index 659464df9dc0e7c8cd276bd0bbf7072361aa3abf..09c743c4275169ba8c53ccbd428100b2fc4483d6 100644 --- a/paddle/fluid/operators/huber_loss_op.cu +++ b/paddle/fluid/operators/huber_loss_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/huber_loss_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/im2sequence_op.cu b/paddle/fluid/operators/im2sequence_op.cu index e0a5a90c1c3c47ea45b3f83ae969c1861783ff60..1c34640618d58d3b5fe627fa6596260a7b687d05 100644 --- a/paddle/fluid/operators/im2sequence_op.cu +++ b/paddle/fluid/operators/im2sequence_op.cu @@ -11,8 +11,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/im2sequence_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/isfinite_op.cu b/paddle/fluid/operators/isfinite_op.cu index 8d1268b18c6fec03063051f545075209a6fcde27..995969cd42f08c7fa948262e42793106e745b3a7 100644 --- a/paddle/fluid/operators/isfinite_op.cu +++ b/paddle/fluid/operators/isfinite_op.cu @@ -11,8 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/l1_norm_op.cu b/paddle/fluid/operators/l1_norm_op.cu index 1b48571dd7378c1c2a6628662024bc7bcc08d2a6..a5c29bbf5debdd11f6e5b28b3a8b48c2c484517a 100644 --- a/paddle/fluid/operators/l1_norm_op.cu +++ b/paddle/fluid/operators/l1_norm_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/l1_norm_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/log_loss_op.cu b/paddle/fluid/operators/log_loss_op.cu index e8bf7d8159bf8b16bf4397e7765918c060124db3..280913c43a2749ddd5fbd3ae1905f1b823dd525d 100644 --- a/paddle/fluid/operators/log_loss_op.cu +++ b/paddle/fluid/operators/log_loss_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/log_loss_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/math/context_project.cu b/paddle/fluid/operators/math/context_project.cu index 16205c0e145ef70666d4eca564488d80bde26d2e..f04b2d15349be329ee228fc8903c9b38a5349634 100644 --- a/paddle/fluid/operators/math/context_project.cu +++ b/paddle/fluid/operators/math/context_project.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/math/context_project.h" namespace paddle { diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc index fead13ebadcd131afafc308740cdd39b1c53bc08..cb49e66488bd69d92430cbf6de1d08348ffe0202 100644 --- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc +++ b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc @@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel { } }; -#define INTRIAVX_FLOAT(isa, block) \ +#define INTRIAVX_FLOAT(isa, jit_block) \ template <> \ - LayerNormKernelImpl::LayerNormKernelImpl(int right) \ + LayerNormKernelImpl::LayerNormKernelImpl(int right) \ : LayerNormKernel() { \ this->num_ = right; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ this->end_ = this->num_ - this->rest_; \ } \ template <> \ - void LayerNormKernelImpl::Compute( \ + void LayerNormKernelImpl::Compute( \ float* x, float* out, float* mean, float* var, const float* scale, \ const float* bias, int height, const float epsilon) const { \ __m256 sum; \ @@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel { __m256 tmp; \ size_t offset; \ size_t j; \ + size_t block = YMM_FLOAT_BLOCK; \ __m256 reverse_num_vec = \ _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \ __m256 epsilon_vec = _mm256_set1_ps(epsilon); \ @@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8); INTRIAVX_FLOAT(platform::avx, kGT8LT16); INTRIAVX_FLOAT(platform::avx, kEQ16); INTRIAVX_FLOAT(platform::avx, kGT16); -#endif -#ifdef __AVX2__ INTRIAVX_FLOAT(platform::avx2, kEQ8); INTRIAVX_FLOAT(platform::avx2, kGT8LT16); INTRIAVX_FLOAT(platform::avx2, kEQ16); INTRIAVX_FLOAT(platform::avx2, kGT16); +INTRIAVX_FLOAT(platform::avx512f, kEQ8); +INTRIAVX_FLOAT(platform::avx512f, kGT8LT16); +INTRIAVX_FLOAT(platform::avx512f, kEQ16); +INTRIAVX_FLOAT(platform::avx512f, kGT16); #endif #undef INTRIAVX_FLOAT diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 79b7538ad05b0ff348b8264d50b63211b5254e80..9372d63f0bea2b0c9f37d47376d7b7014e381a33 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/math/sequence2batch.cu b/paddle/fluid/operators/math/sequence2batch.cu index be73adfc0cbe37ed8831b5ad34e66bc95e342e9d..9ab13659c1cc5b59d28395bcebcfb43fac5b4544 100644 --- a/paddle/fluid/operators/math/sequence2batch.cu +++ b/paddle/fluid/operators/math/sequence2batch.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/math/sequence2batch.h" namespace paddle { diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu index 2e9669049e36478549b793e3fa76220825888e21..71d137398267f61d8cc01907d6a9498eef8d62dc 100644 --- a/paddle/fluid/operators/math/softmax.cu +++ b/paddle/fluid/operators/math/softmax.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include #include "paddle/fluid/operators/math/math_function.h" diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 413b8ace67bd0a36849373812950834523b62216..921c2e1298906655767c1e7f30dc34b2c564c671 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/mean_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cu b/paddle/fluid/operators/optimizers/adadelta_op.cu index 3fbfee5df05770a1206ab3170d3baffdd20bc77b..562a157f063b44d65254d556d44439eee3636c4c 100644 --- a/paddle/fluid/operators/optimizers/adadelta_op.cu +++ b/paddle/fluid/operators/optimizers/adadelta_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/adadelta_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cu b/paddle/fluid/operators/optimizers/adagrad_op.cu index 4efe56855a4bdca41d24f02c29a618a8d4232887..5043468d4c5f721ae0906b1a319eb3ec10b26580 100644 --- a/paddle/fluid/operators/optimizers/adagrad_op.cu +++ b/paddle/fluid/operators/optimizers/adagrad_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/optimizers/adagrad_op.h" diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu index e8090ebacfe85153aba9e275c9cd1c55fd7af15e..4eb2db717d45a730798eef48d3d10bce9d387c4b 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cu +++ b/paddle/fluid/operators/optimizers/adam_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/adam_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/optimizers/adamax_op.cu b/paddle/fluid/operators/optimizers/adamax_op.cu index e54adcb142fe0d50dad23fe5df14bd6f28220d8a..80e0219d4414db2909b5babc22599d8c0d906c7d 100644 --- a/paddle/fluid/operators/optimizers/adamax_op.cu +++ b/paddle/fluid/operators/optimizers/adamax_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/adamax_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu index 84d65e39329659f82099011f9ec60468d5db6328..dc568802a2b19fee5c8d7fd8d07c929cba8ab4e3 100644 --- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu +++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cu b/paddle/fluid/operators/optimizers/ftrl_op.cu index f836b75df93861a0fd670f2a0e786e6a797a4661..acf8e38ca0f5a3cf9899f4898898013e8a2afdd2 100644 --- a/paddle/fluid/operators/optimizers/ftrl_op.cu +++ b/paddle/fluid/operators/optimizers/ftrl_op.cu @@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/ftrl_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu index d1c1f747b70c3ceb806da06e6786a70b62a32995..591dead3b12763e4cd1b9c390a87816ab121fbf8 100644 --- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu +++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu @@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/proximal_adagrad_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cu b/paddle/fluid/operators/optimizers/proximal_gd_op.cu index 7aa0e1015008eba0c1cf63ba1278dc2b8049b20b..d556fa74f19529d0e2f80d4c6dbfca62498c9dcc 100644 --- a/paddle/fluid/operators/optimizers/proximal_gd_op.cu +++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cu @@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/proximal_gd_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cu b/paddle/fluid/operators/optimizers/rmsprop_op.cu index 69e35a309e04f61068d9ff1b6d9f1450d2524253..8b17d6a0204045a9b20adb79dbad72dff5ba267e 100644 --- a/paddle/fluid/operators/optimizers/rmsprop_op.cu +++ b/paddle/fluid/operators/optimizers/rmsprop_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/optimizers/rmsprop_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/pad_constant_like_op.cu b/paddle/fluid/operators/pad_constant_like_op.cu index ea69577904577de353b63491973bf74b7724e18e..9e62a6dc9d34a96c59a08d0e5fd6cdd9f0d6d51d 100644 --- a/paddle/fluid/operators/pad_constant_like_op.cu +++ b/paddle/fluid/operators/pad_constant_like_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/pad_constant_like_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/pad_op.cu b/paddle/fluid/operators/pad_op.cu index 9cddef9cf1d3c43701a4f0ed3f70dcb30c1dbd02..95098a8dca36594c3af60ad8488217e71c673a75 100644 --- a/paddle/fluid/operators/pad_op.cu +++ b/paddle/fluid/operators/pad_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/pad_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6978d9c5dc5993e64793f420a63dcca020f47868 --- /dev/null +++ b/paddle/fluid/operators/psroi_pool_op.cc @@ -0,0 +1,173 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/psroi_pool_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +class PSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), " + "the input of PSROIPoolOp. " + "The format of input tensor is NCHW. Where N is the batch size, " + "C is the number of input channels, " + "H is the height of the input feature map, and " + "W is the width."); + AddInput("ROIs", + "(LoDTensor), " + "ROIs (Regions of Interest) to pool over. " + "should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]. " + "where (x1, y1) is the top left coordinates, and " + "(x2, y2) is the bottom right coordinates. " + "The roi batch index can be calculated from LoD."); + AddOutput("Out", + "(Tensor), " + "the output of PSROIPoolOp is a 4-D Tensor with shape " + "(num_rois, output_channels, pooled_h, pooled_w)."); + AddAttr( + "output_channels", + "(int), " + "the number of channels of the output feature map. " + "For a task of C classes of objects, output_channels should be " + "(C + 1) for classification only."); + AddAttr("spatial_scale", + "(float, default 1.0), " + "Multiplicative spatial scale factor " + "to translate ROI coords from their input scale " + "to the scale used when pooling.") + .SetDefault(1.0); + AddAttr("pooled_height", + "(int, default 1), " + "the pooled output height.") + .SetDefault(1); + AddAttr("pooled_width", + "(int, default 1), " + "the pooled output width.") + .SetDefault(1); + AddComment(R"Doc( +**PSROIPool Operator** + +Position sensitive region of interest pooling (also known as PSROIPooling) is to perform +position-sensitive average pooling on regions of interest specified by input, takes as +input N position-sensitive score maps and a list of num_rois regions of interest. + +PSROIPooling for R-FCN. Please refer to https://arxiv.org/abs/1605.06409 for more details. + )Doc"); + } +}; + +class PSROIPoolOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of PSROIPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("ROIs"), + "Input(ROIs) of PSROIPoolOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of PSROIPoolOp should not be null."); + auto input_dims = ctx->GetInputDim("X"); + auto rois_dims = ctx->GetInputDim("ROIs"); + + PADDLE_ENFORCE(input_dims.size() == 4, + "The format of input tensor is NCHW"); + PADDLE_ENFORCE(rois_dims.size() == 2, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]"); + PADDLE_ENFORCE(rois_dims[1] == 4, + "ROIs should be a 2-D LoDTensor of shape (num_rois, 4) " + "given as [(x1, y1, x2, y2), ...]"); + + int pooled_height = ctx->Attrs().Get("pooled_height"); + int pooled_width = ctx->Attrs().Get("pooled_width"); + int output_channels = ctx->Attrs().Get("output_channels"); + float spatial_scale = ctx->Attrs().Get("spatial_scale"); + + PADDLE_ENFORCE( + input_dims[1] == output_channels * pooled_height * pooled_width, + "the channel of X(%d) should be equal to the product of " + "output_channels(%d), pooled_height(%d) and pooled_width(%d)", + input_dims[1], output_channels, pooled_height, pooled_width); + + PADDLE_ENFORCE_GT(pooled_height, 0, + "The pooled output height must be greater than 0"); + PADDLE_ENFORCE_GT(pooled_width, 0, + "The pooled output width must be greater than 0"); + PADDLE_ENFORCE_GT(output_channels, 1, + "The pooled output channels must greater than 1"); + PADDLE_ENFORCE_GT(spatial_scale, 0.0f, + "The spatial scale must greater than 0."); + + auto out_dims = input_dims; + out_dims[0] = rois_dims[0]; + out_dims[1] = + output_channels; // input_dims[1] / (pooled_height * pooled_width); + out_dims[2] = pooled_height; + out_dims[3] = pooled_width; + ctx->SetOutputDim("Out", out_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +class PSROIPoolGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "The gradient of Out should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "The gradient of X should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(psroi_pool, ops::PSROIPoolOp, ops::PSROIPoolOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(psroi_pool_grad, ops::PSROIPoolGradOp); +REGISTER_OP_CPU_KERNEL( + psroi_pool, + ops::CPUPSROIPoolOpKernel, + ops::CPUPSROIPoolOpKernel); +REGISTER_OP_CPU_KERNEL( + psroi_pool_grad, + ops::CPUPSROIPoolGradOpKernel, + ops::CPUPSROIPoolGradOpKernel); diff --git a/paddle/fluid/operators/psroi_pool_op.cu b/paddle/fluid/operators/psroi_pool_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..22fec3244fabe5ca466202784c0cce372d0bf6e5 --- /dev/null +++ b/paddle/fluid/operators/psroi_pool_op.cu @@ -0,0 +1,294 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/psroi_pool_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); +} + +template +__global__ void GPUPSROIPoolForward( + const int nthreads, const T* input_data, const T* input_rois, + const float spatial_scale, const int input_channels, const int height, + const int width, const int output_channels, const int pooled_height, + const int pooled_width, const int* rois_batch_id_data, T* output_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + + // Force too small ROIs to be 1x1 + T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = max(roi_end_w - roi_start_w, (T)0.1); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); + int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); + int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); + int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart, 0), height); + hend = min(max(hend, 0), height); + wstart = min(max(wstart, 0), width); + wend = min(max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + const T* offset_input_data = + input_data + + (roi_batch_id * input_channels + input_channel) * height * width; + T outsum = 0; + + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * width + iw; + outsum += offset_input_data[input_index]; + } + } + + T bin_area = static_cast((hend - hstart) * (wend - wstart)); + output_data[i] = is_empty ? 0. : outsum / bin_area; + } +} + +template +__global__ void GPUPSROIPoolBackward( + const int nthreads, const T* input_rois, const T* output_grad_data, + const float spatial_scale, const int input_channels, const int height, + const int width, const int output_channels, const int pooled_height, + const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_offset = + (roi_batch_id * input_channels + input_channel) * height * width; + T* offset_input_grad_data = input_grad_data + input_offset; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + + // Force too small ROIs to be 1x1 + T roi_height = max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = max(roi_end_w - roi_start_w, (T)0.1); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); + int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); + int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); + int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); + + // Add roi offsets and clip to input boundaries + hstart = min(max(hstart, 0), height); + hend = min(max(hend, 0), height); + wstart = min(max(wstart, 0), width); + wend = min(max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Accumulate diff_val into input data + T bin_area = static_cast((hend - hstart) * (wend - wstart)); + T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area; + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * width + iw; + platform::CudaAtomicAdd(offset_input_grad_data + input_index, diff_val); + } + } + } +} + +template +class GPUPSROIPoolOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto output_channels = ctx.Attr("output_channels"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + + PADDLE_ENFORCE_EQ(input_channels, + output_channels * pooled_height * pooled_width, + "the channels of input X should equal the product of " + "output_channels x pooled_height x pooled_width"); + + int rois_num = rois->dims()[0]; + if (rois_num == 0) return; + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "The rois_batch_size and input(X) batch_size must be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod, + "The rois_num from input and lod must be the same."); + + // set rois batch id + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(platform::CPUPlace()); + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + framework::Tensor rois_batch_id_list_gpu; + framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), + ctx.device_context(), &rois_batch_id_list_gpu); + + int output_size = out->numel(); + int blocks = NumBlocks(output_size); + int threads = kNumCUDAThreads; + + // call cuda kernel function + GPUPSROIPoolForward< + T><<>>( + output_size, in->data(), rois->data(), spatial_scale, + input_channels, height, width, output_channels, pooled_height, + pooled_width, rois_batch_id_list_gpu.data(), + out->mutable_data(ctx.GetPlace())); + } +}; + +template +class GPUPSROIPoolGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + + auto* output_grad = ctx.Input(framework::GradVarName("Out")); + auto* input_grad = ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto output_channels = ctx.Attr("output_channels"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + int rois_num = rois->dims()[0]; + int input_channels = in->dims()[1]; + int height = in->dims()[2]; + int width = in->dims()[3]; + + if (input_grad) { + // set roi batch id + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(platform::CPUPlace()); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + framework::Tensor rois_batch_id_list_gpu; + framework::TensorCopy(rois_batch_id_list, ctx.GetPlace(), + ctx.device_context(), &rois_batch_id_list_gpu); + + input_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.cuda_device_context(), input_grad, static_cast(0)); + + int output_grad_size = output_grad->numel(); + int blocks = NumBlocks(output_grad_size); + int threads = kNumCUDAThreads; + + if (output_grad_size > 0) { + GPUPSROIPoolBackward< + T><<>>( + output_grad_size, rois->data(), output_grad->data(), + spatial_scale, input_channels, height, width, output_channels, + pooled_height, pooled_width, rois_batch_id_list_gpu.data(), + input_grad->mutable_data(ctx.GetPlace())); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + psroi_pool, + ops::GPUPSROIPoolOpKernel, + ops::GPUPSROIPoolOpKernel); +REGISTER_OP_CUDA_KERNEL( + psroi_pool_grad, + ops::GPUPSROIPoolGradOpKernel, + ops::GPUPSROIPoolGradOpKernel); diff --git a/paddle/fluid/operators/psroi_pool_op.h b/paddle/fluid/operators/psroi_pool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..1a424728f7f6c4034242fb998d5121804e38702b --- /dev/null +++ b/paddle/fluid/operators/psroi_pool_op.h @@ -0,0 +1,253 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class CPUPSROIPoolOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto output_channels = ctx.Attr("output_channels"); + + auto in_dims = in->dims(); + int batch_size = in_dims[0]; + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = rois->dims()[0]; + + auto in_stride = framework::stride(in_dims); + auto roi_stride = framework::stride(rois->dims()); + auto out_stride = framework::stride(out->dims()); + + const T* input_data = in->data(); + + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(ctx.GetPlace()); + + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ( + rois_batch_size, batch_size, + "the rois_batch_size and input(X) batch_size should be the same."); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ(rois_num_with_lod, rois_num, + "the rois_num from input and lod must be the same"); + + PADDLE_ENFORCE_EQ(input_channels, + output_channels * pooled_height * pooled_width, + "the channels of input X should equal the product of " + "output_channels x pooled_height x pooled_width"); + + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + T* output_data = out->mutable_data(ctx.GetPlace()); + const T* input_rois = rois->data(); + + // calculate psroipooling, parallel processing can be implemented per ROI + for (int n = 0; n < rois_num; ++n) { + // set roi batch id + int roi_batch_id = rois_batch_id_data[n]; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = + static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = + static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + + // Force too small rois to be 1 x 1 + T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1); + + // Compute bin size w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + // calculate each pixel of the output feature map. + int out_roi_offset = n * out_stride[0]; + for (int c = 0; c < output_channels; ++c) { + // per category + int out_plane_offset = out_roi_offset + c * out_stride[1]; + for (int ph = 0; ph < pooled_height; ++ph) { + int out_row_offset = out_plane_offset + ph * out_stride[2]; + for (int pw = 0; pw < pooled_width; ++pw) { + // calculate w and h at input feature map + int hstart = floor(static_cast(ph) * bin_size_h + roi_start_h); + int wstart = floor(static_cast(pw) * bin_size_w + roi_start_w); + int hend = ceil(static_cast(ph + 1) * bin_size_h + roi_start_h); + int wend = ceil(static_cast(pw + 1) * bin_size_w + roi_start_w); + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart, 0), height); + wstart = std::min(std::max(wstart, 0), width); + hend = std::min(std::max(hend, 0), height); + wend = std::min(std::max(wend, 0), width); + + int output_index = out_row_offset + pw; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_plane_offset = + roi_batch_id * in_stride[0] + input_channel * in_stride[1]; + const T* offset_input_data = input_data + input_plane_offset; + T out_sum = 0.; + bool is_empty = (hend <= hstart) || (wend <= wstart); + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * in_stride[2] + iw; + out_sum += offset_input_data[input_index]; + } + } + T bin_area = (hend - hstart) * (wend - wstart); + output_data[output_index] = is_empty ? 0. : out_sum / bin_area; + } + } + } + } + return; + } +}; + +template +class CPUPSROIPoolGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* output_grad = + ctx.Input(framework::GradVarName("Out")); + auto* input_grad = + ctx.Output(framework::GradVarName("X")); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto output_channels = ctx.Attr("output_channels"); + auto spatial_scale = ctx.Attr("spatial_scale"); + + if (input_grad) { + auto in_dims = in->dims(); + int input_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int rois_num = rois->dims()[0]; + + // set roi batch id + framework::Tensor rois_batch_id_list; + rois_batch_id_list.Resize({rois_num}); + int* rois_batch_id_data = + rois_batch_id_list.mutable_data(ctx.GetPlace()); + auto rois_lod = rois->lod().back(); + int rois_batch_size = rois_lod.size() - 1; + // calculate batch id index for each roi according to LoD + for (int n = 0; n < rois_batch_size; ++n) { + for (size_t i = rois_lod[n]; i < rois_lod[n + 1]; ++i) { + rois_batch_id_data[i] = n; + } + } + + const T* input_rois = rois->data(); + const T* output_grad_data = output_grad->data(); + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + + // set gradient of X to be 0. before backpropagate. + math::SetConstant set_zero; + set_zero(ctx.template device_context(), input_grad, + static_cast(0)); + + // backpropagate gradient per output pixel + int output_grad_size = output_grad->numel(); + for (int i = 0; i < output_grad_size; ++i) { + // The output is in order (n, c, ph, pw) + int pw = i % pooled_width; + int ph = (i / pooled_width) % pooled_height; + int c = (i / pooled_width / pooled_height) % output_channels; + int n = i / pooled_width / pooled_height / output_channels; + + // set roi_batch_id + int roi_batch_id = rois_batch_id_data[n]; + int input_channel = (c * pooled_height + ph) * pooled_width + pw; + int input_offset = + (roi_batch_id * input_channels + input_channel) * height * width; + T* offset_input_grad_data = input_grad_data + input_offset; + + // [start, end) interval for spatial sampling + const T* offset_input_rois = input_rois + n * 4; + T roi_start_w = + static_cast(round(offset_input_rois[0])) * spatial_scale; + T roi_start_h = + static_cast(round(offset_input_rois[1])) * spatial_scale; + T roi_end_w = + static_cast(round(offset_input_rois[2]) + 1.) * spatial_scale; + T roi_end_h = + static_cast(round(offset_input_rois[3]) + 1.) * spatial_scale; + + // Force too small ROIs to be 1x1 + T roi_height = std::max(roi_end_h - roi_start_h, (T)0.1); // avoid 0 + T roi_width = std::max(roi_end_w - roi_start_w, (T)0.1); + + // Compute w and h at input feature map + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + int hstart = floor(bin_size_h * static_cast(ph) + roi_start_h); + int wstart = floor(bin_size_w * static_cast(pw) + roi_start_w); + int hend = ceil(bin_size_h * static_cast(ph + 1) + roi_start_h); + int wend = ceil(bin_size_w * static_cast(pw + 1) + roi_start_w); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart, 0), height); + hend = std::min(std::max(hend, 0), height); + wstart = std::min(std::max(wstart, 0), width); + wend = std::min(std::max(wend, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + // Accumulate diff_val into input data + T bin_area = static_cast((hend - hstart) * (wend - wstart)); + T diff_val = is_empty ? 0. : output_grad_data[i] / bin_area; + for (int ih = hstart; ih < hend; ++ih) { + for (int iw = wstart; iw < wend; ++iw) { + int input_index = ih * width + iw; + offset_input_grad_data[input_index] += diff_val; + } + } + } + } + return; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reader/ctr_reader.h b/paddle/fluid/operators/reader/ctr_reader.h index 635483158fc2892d408c419a10ae0358b9ef98de..56879ffda5d3e04a88d12d6c4701c24a0d0ee4f7 100644 --- a/paddle/fluid/operators/reader/ctr_reader.h +++ b/paddle/fluid/operators/reader/ctr_reader.h @@ -16,6 +16,7 @@ #include +#include #include // NOLINT #include #include @@ -55,8 +56,7 @@ class CTRReader : public framework::FileReader { PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!"); PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null"); PADDLE_ENFORCE_GT(file_list.size(), 0, "file list should not be empty"); - thread_num_ = - file_list_.size() > thread_num ? thread_num : file_list_.size(); + thread_num_ = std::min(file_list_.size(), thread_num); queue_ = queue; SplitFiles(); for (size_t i = 0; i < thread_num_; ++i) { @@ -96,9 +96,9 @@ class CTRReader : public framework::FileReader { VLOG(3) << "reopen success"; VLOG(3) << "thread_num " << thread_num_; for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) { - read_threads_.emplace_back(new std::thread( - std::bind(&ReadThread, file_groups_[thread_id], slots_, batch_size_, - thread_id, &read_thread_status_, queue_))); + read_threads_.emplace_back(new std::thread(std::bind( + &ReadThread, file_groups_[thread_id], slots_, batch_size_, + static_cast(thread_id), &read_thread_status_, queue_))); } monitor_thread_.reset(new std::thread( std::bind(&MonitorThread, &read_thread_status_, queue_))); diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu index 63cd47a38a0ff6413c430c6be6284c5f4bfc2595..4897474a485d8417854ffb53aa8ee64321c78ae7 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu index 9aadac1a416034a3510dea2916d7577efbc2f8c2..a1fbc7e5fab71df486b53c31464c99e9c4557ccd 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/smooth_l1_loss_op.cu b/paddle/fluid/operators/smooth_l1_loss_op.cu index dfbb5c905884b57413587a4f6c33b0238b740c73..e5df479090fabe926f65f58e2300e3ee2027e54d 100644 --- a/paddle/fluid/operators/smooth_l1_loss_op.cu +++ b/paddle/fluid/operators/smooth_l1_loss_op.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/smooth_l1_loss_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 6d48796191dd13a45f0c7267bfaf05489f528a9d..cee3e87037e0f1439a08b7b275eedefe357a4b13 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h index af64607fafc6544047714e731846a2440be219b8..1fef2b3d378c96d087118d0136885e7e29aa237c 100644 --- a/paddle/fluid/operators/split_selected_rows_op.h +++ b/paddle/fluid/operators/split_selected_rows_op.h @@ -72,10 +72,11 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { for (size_t i = 0; i < outs_rows_idx.size(); ++i) { auto rows_idx = outs_rows_idx[i]; outs[i]->set_height(height_sections[i]); + auto dims = x->GetCompleteDims(); + dims[0] = rows_idx.size(); + outs[i]->mutable_value()->mutable_data(dims, x->place()); + outs[i]->mutable_rows()->clear(); if (rows_idx.size() > 0) { - auto dims = x->GetCompleteDims(); - dims[0] = rows_idx.size(); - outs[i]->mutable_value()->mutable_data(dims, x->place()); for (auto idx : rows_idx) { outs[i]->mutable_rows()->push_back(idx - abs_sections[i]); } @@ -98,6 +99,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel { } } } + PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(), + "rows should has the same size with tensor dim 0"); } } }; diff --git a/paddle/fluid/operators/squared_l2_distance_op.cu b/paddle/fluid/operators/squared_l2_distance_op.cu index 3e80ae8dd22077c0f9bbdedc24e84f6c339c5a26..c9264da838246efded7d9f85664faf0dc1cec282 100644 --- a/paddle/fluid/operators/squared_l2_distance_op.cu +++ b/paddle/fluid/operators/squared_l2_distance_op.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/squared_l2_distance_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/squared_l2_norm_op.cu b/paddle/fluid/operators/squared_l2_norm_op.cu index 87830413da3f141f01a97966ae0e2b0501ed600a..e31cfeb78ab8a8d1b55a198fe7a2c647a3dce665 100644 --- a/paddle/fluid/operators/squared_l2_norm_op.cu +++ b/paddle/fluid/operators/squared_l2_norm_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/squared_l2_norm_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index db4c2d6c115f04b436db00854ca4b02fea09866b..6125ed07b6d0f92fa317c581a06117dcfa7359ae 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/sum_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 93cb5eb2dc0b3480ebd05dcc6b36d8915d057bab..23c7ebe84221986a5f7ac7583c3a8e17d04fe4af 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -56,9 +56,16 @@ ELSE() set(MKLDNN_CTX_DEPS) ENDIF() +nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce) +IF(WITH_GPU) + set(STREAM_CALLBACK_DEPS stream_callback_manager) +ELSE() + set(STREAM_CALLBACK_DEPS) +ENDIF() + # memcpy depends on device_context, here add deps individually for # avoiding cycle dependencies -cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc +cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS} place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_info) diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu index ee45afab93d079374aefe366425502890854c28d..466bf90c63c1496883995819cdcb19f846e4a302 100644 --- a/paddle/fluid/platform/cuda_helper_test.cu +++ b/paddle/fluid/platform/cuda_helper_test.cu @@ -93,7 +93,7 @@ TEST(CudaAtomic, float16) { // unalignment of uint8 void TestUnalign(size_t num, const int shift_bit) { - PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2"); + ASSERT_EQ(num % 2, 0); float16 *in1, *in2, *out; float16 *d_in1, *d_in2; size_t size = sizeof(uint8_t) * (num + shift_bit); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 146a205832fd31faf01effdce01eb2884f9a9884..bd81d4dd1f1073edffcb9fd4a02b455db27361d5 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -220,6 +220,40 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) LOG_FIRST_N(WARNING, 1) << "device: " << place_.device << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "." << (cudnn_dso_ver % 100) / 10 << "."; + + { + // Check CUDA/CUDNN version compatiblity + auto local_cuda_version = runtime_version_ / 100; + auto compile_cuda_version = CUDA_VERSION / 100; + if (local_cuda_version < compile_cuda_version) { + LOG_FIRST_N(WARNING, 1) + << "WARNING: device: " << place_.device + << ". The installed Paddle is compiled with CUDA " + << compile_cuda_version / 10 << "." << compile_cuda_version % 10 + << ", but CUDA runtime version in your machine is " + << local_cuda_version / 10 << "." << local_cuda_version % 10 + << ", which may cause serious incompatible bug. " + << "Please recompile or reinstall Paddle with compatible CUDA " + "version."; + } + + if (dynload::HasCUDNN()) { + auto local_cudnn_version = cudnn_dso_ver / 100; + auto compile_cudnn_version = CUDNN_VERSION / 100; + if (local_cuda_version < compile_cuda_version) { + LOG_FIRST_N(WARNING, 1) + << "WARNING: device: " << place_.device + << ". The installed Paddle is compiled with CUDNN " + << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10 + << ", but CUDNN version in your machine is " + << local_cudnn_version / 10 << "." << local_cudnn_version % 10 + << ", which may cause serious incompatible bug. " + << "Please recompile or reinstall Paddle with compatible CUDNN " + "version."; + } + } + } + callback_manager_.reset(new StreamCallbackManager(stream_)); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 3edd727978010e20203ab994562ce922b6ee0bad..812e56f1f966d03207cf83ad47cb88e9fa5d55bb 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/gpu_info.h" -#define EIGEN_USE_GPU #endif #ifdef PADDLE_WITH_MKLDNN @@ -223,14 +222,10 @@ class CUDADeviceContext : public DeviceContext { template void AddStreamCallback(Callback&& callback) const { - std::lock_guard guard(callback_mtx_); callback_manager_->AddCallback(callback); } - void WaitStreamCallback() const { - std::lock_guard guard(callback_mtx_); - callback_manager_->Wait(); - } + void WaitStreamCallback() const { callback_manager_->Wait(); } #if CUDA_VERSION >= 9000 /*! \brief CublasCall may need to change cublas's config, @@ -261,9 +256,7 @@ class CUDADeviceContext : public DeviceContext { mutable std::mutex mtx_; - // This lock is only used by callback - // If we use mtx_ for StreamCallbackManager, deadlock may occur sometimes - mutable std::mutex callback_mtx_; + // StreamCallbackManager is thread-safe std::unique_ptr callback_manager_; mutable std::mutex cublas_mtx_; diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index a85972bdb72ca3119cc14f9e2b810c3875443538..01ee67fd07f848356e801be95d53a61bb5b08e37 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -62,45 +62,54 @@ inline std::string demangle(std::string name) { return name; } #endif struct EnforceNotMet : public std::exception { - std::exception_ptr exp_; std::string err_str_; - EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) { - static constexpr int TRACE_STACK_LIMIT = 100; + EnforceNotMet(std::exception_ptr e, const char* f, int l) { try { - std::rethrow_exception(exp_); - } catch (const std::exception& exp) { - std::ostringstream sout; + std::rethrow_exception(e); + } catch (std::exception& e) { + Init(e.what(), f, l); + } + } - sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; - sout << "PaddlePaddle Call Stacks: " << std::endl; + template + EnforceNotMet(const char* f, int l, ARGS... args) { + Init(string::Sprintf(args...), f, l); + } + + const char* what() const noexcept override { return err_str_.c_str(); } + + private: + template + inline void Init(StrType what, const char* f, int l) { + static constexpr int TRACE_STACK_LIMIT = 100; + std::ostringstream sout; + + sout << string::Sprintf("%s at [%s:%d]", what, f, l) << std::endl; + sout << "PaddlePaddle Call Stacks: " << std::endl; #if !defined(_WIN32) - void* call_stack[TRACE_STACK_LIMIT]; - auto size = backtrace(call_stack, TRACE_STACK_LIMIT); - auto symbols = backtrace_symbols(call_stack, size); - - Dl_info info; - for (int i = 0; i < size; ++i) { - if (dladdr(call_stack[i], &info) && info.dli_sname) { - auto demangled = demangle(info.dli_sname); - auto addr_offset = static_cast(call_stack[i]) - - static_cast(info.dli_saddr); - sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, - 2 + sizeof(void*) * 2, call_stack[i], - demangled, addr_offset); - } else { - sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2, - call_stack[i]); - } + void* call_stack[TRACE_STACK_LIMIT]; + auto size = backtrace(call_stack, TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack, size); + Dl_info info; + for (int i = 0; i < size; ++i) { + if (dladdr(call_stack[i], &info) && info.dli_sname) { + auto demangled = demangle(info.dli_sname); + auto addr_offset = static_cast(call_stack[i]) - + static_cast(info.dli_saddr); + sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, + 2 + sizeof(void*) * 2, call_stack[i], demangled, + addr_offset); + } else { + sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2, + call_stack[i]); } - free(symbols); + } + free(symbols); #else - sout << "Windows not support stack backtrace yet."; + sout << "Windows not support stack backtrace yet."; #endif - err_str_ = sout.str(); - } + err_str_ = sout.str(); } - - const char* what() const noexcept { return err_str_.c_str(); } }; struct EOFException : public std::exception { @@ -242,13 +251,8 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::platform::EnforceNotMet( \ - std::make_exception_ptr( \ - std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ - __FILE__, __LINE__); \ - } while (false) +#define PADDLE_THROW(...) \ + throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__) #ifndef REPLACE_ENFORCE_GLOG #define PADDLE_ENFORCE(...) \ diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 9d48557caf75f3571ead3df43a1a93cf65e4b8cb..98afe843c0035ec14ad874508dc02b8d1d3d359c 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -71,9 +71,6 @@ struct float16; } // namespace platform } // namespace paddle -// NOTE(): -// Do not move the eigen.h header, otherwise the eigen_vector will failed. -#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/platform/hostdevice.h" #include "unsupported/Eigen/CXX11/Tensor" diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..466c77469ef256179c52442d21c1d62dfc4ef1bb --- /dev/null +++ b/paddle/fluid/platform/stream_callback_manager.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/stream_callback_manager.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +#if CUDA_VERSION >= 10000 +static void CUDART_CB StreamCallbackFunc(void *user_data); +#else +static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, + cudaError_t status, void *user_data) +#endif +{ + std::unique_ptr> func( + reinterpret_cast *>(user_data)); + (*func)(); +} + +StreamCallbackManager::StreamCallbackManager(const cudaStream_t stream) + : stream_(stream), thread_pool_(1) {} + +void StreamCallbackManager::AddCallback(std::function callback) const { + auto *callback_func = new std::function(std::move(callback)); + auto *func = new std::function([this, callback_func] { + std::lock_guard lock(mtx_); + last_future_ = thread_pool_.enqueue([callback_func] { + std::unique_ptr> releaser(callback_func); + (*callback_func)(); + }); + }); +#if CUDA_VERSION >= 10000 + PADDLE_ENFORCE(cudaLaunchHostFunc(stream_, StreamCallbackFunc, func)); +#else + PADDLE_ENFORCE(cudaStreamAddCallback(stream_, StreamCallbackFunc, func, 0)); +#endif +} + +void StreamCallbackManager::Wait() const { + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); + { + std::lock_guard lock(mtx_); + if (last_future_.valid()) { + last_future_.wait(); + } + } +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/stream_callback_manager.h b/paddle/fluid/platform/stream_callback_manager.h index ed8734c98cb996721a3002523c9276e6d7f492ae..8668bcb1131719e882ecbccb08ad00b63409eb28 100644 --- a/paddle/fluid/platform/stream_callback_manager.h +++ b/paddle/fluid/platform/stream_callback_manager.h @@ -18,67 +18,32 @@ #include #include #include +#include // NOLINT #include +#include // NOLINT + #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace platform { -class StreamCallbackManager; - -struct StreamCallbackContext { - template - inline StreamCallbackContext(const StreamCallbackManager *manager, - Callback &&callback) - : manager_(manager), callback_(callback) {} - - const StreamCallbackManager *manager_; // do not own - std::function callback_; -}; - +// NOTE(zjl): clean StreamCallbackManager to make compilation faster +// Make StreamCallbackManager thread-safe class StreamCallbackManager { public: - explicit inline StreamCallbackManager(cudaStream_t stream = nullptr) - : stream_(stream), thread_pool_(new ThreadPool(1)) {} + explicit StreamCallbackManager(const cudaStream_t stream); + + ~StreamCallbackManager() = default; - template - inline void AddCallback(Callback &&callback) const { - auto *stream_callback_context = - new StreamCallbackContext(this, std::forward(callback)); -#if CUDA_VERSION >= 10000 - PADDLE_ENFORCE(cudaLaunchHostFunc(stream_, - StreamCallbackManager::StreamCallbackFunc, - stream_callback_context)); // NOLINT -#else - PADDLE_ENFORCE(cudaStreamAddCallback( - stream_, StreamCallbackManager::StreamCallbackFunc, - stream_callback_context, 0)); // NOLINT -#endif - } + void AddCallback(std::function callback) const; - void Wait() const { thread_pool_.reset(new ThreadPool(1)); } + void Wait() const; private: const cudaStream_t stream_; - mutable std::unique_ptr thread_pool_; - -// cudaStreamCallback cannot call CUDA API inside, so we have to use -// thread_pool here -#if CUDA_VERSION >= 10000 - static void CUDART_CB StreamCallbackFunc(void *user_data) -#else - static void CUDART_CB StreamCallbackFunc(cudaStream_t stream, - cudaError_t status, void *user_data) -#endif - { - auto *callback_context_ptr = - reinterpret_cast(user_data); - callback_context_ptr->manager_->thread_pool_->enqueue([=]() { - std::unique_ptr callback_context( - callback_context_ptr); - callback_context->callback_(); - }); - } + mutable ::ThreadPool thread_pool_; + mutable std::mutex mtx_; + mutable std::future last_future_; }; } // namespace platform diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index d602613fc82223e14f48830a87533880696eb550..b8954cb12628d1f4f333956e0213ddf9c01e592c 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,6 +1,7 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler) -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc) +set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer) +set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc) + if(WITH_PYTHON) if(WITH_AMD_GPU) hip_library(paddle_pybind SHARED diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc new file mode 100644 index 0000000000000000000000000000000000000000..34e9c897d9e95feb185083b7c0a6a824d8dc809c --- /dev/null +++ b/paddle/fluid/pybind/imperative.cc @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/pybind/imperative.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/imperative/tracer.h" + +namespace paddle { +namespace pybind { + +// Bind Methods +void BindTracer(pybind11::module *m) { + pybind11::class_(*m, "Tracer", "") + .def("__init__", + [](imperative::Tracer &self, framework::BlockDesc *root_block) { + new (&self) imperative::Tracer(root_block); + }) + .def("trace", &imperative::Tracer::Trace) + .def("get_scope", &imperative::Tracer::GetScope, + pybind11::return_value_policy::reference); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/imperative.h b/paddle/fluid/pybind/imperative.h new file mode 100644 index 0000000000000000000000000000000000000000..7a9d3a01ea81f11ac85000c3d0153f20e108789a --- /dev/null +++ b/paddle/fluid/pybind/imperative.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include +#include "paddle/fluid/imperative/layer.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace paddle { +namespace pybind { + +class PyLayer : public imperative::Layer { + public: + using imperative::Layer::Layer; // Inherit constructors + + std::vector Forward( + const std::vector& inputs) override { + PYBIND11_OVERLOAD(std::vector, Layer, Forward, + inputs); // NOLINT + } + + void Backward() override { + PYBIND11_OVERLOAD(void, Layer, Backward, ); // NOLINT + } +}; + +class PyOpBase : public imperative::OpBase { + public: + using imperative::OpBase::OpBase; // Inherit constructors +}; + +class PyVarBase : public imperative::VarBase { + public: + using imperative::VarBase::VarBase; // Inherit constructors +}; + +void BindTracer(pybind11::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3beb93e7b3ebe83e5b5c82afad280512f043e175..1fa91114a8cda4080490a0a99e2fc3f891b30417 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -34,6 +34,7 @@ limitations under the License. */ #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/version.h" +#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" @@ -45,6 +46,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/async_executor_py.h" #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/exception.h" +#include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/protobuf.h" #include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/recordio.h" @@ -100,6 +102,42 @@ PYBIND11_MODULE(core, m) { BindException(&m); + py::class_(m, "VarBase", R"DOC()DOC") + .def(py::init<>()) + .def("_run_backward", + [](imperative::VarBase &self, framework::Scope *scope) { + self.RunBackward(scope); + }) + .def("_grad", &imperative::VarBase::Grad) + .def_property( + "desc", + [](const imperative::VarBase &self) { return self.var_desc_; }, + [](imperative::VarBase &self, framework::VarDesc *var_desc) { + self.var_desc_ = var_desc; + }, + py::return_value_policy::reference); + + py::class_(m, "OpBase", R"DOC()DOC") + .def(py::init<>()) + .def_property( + "desc", [](const imperative::OpBase &self) { return self.op_desc_; }, + [](imperative::OpBase &self, framework::OpDesc *op_desc) { + if (op_desc) { + self.op_desc_ = op_desc; + } + }, + py::return_value_policy::reference); + + py::class_ layer(m, "Layer"); + layer.def(py::init<>()) + .def("forward", + [](imperative::Layer &self, + const std::vector &inputs) { + return self.Forward(inputs); + }) + .def("backward", &imperative::Layer::Backward); + BindTracer(&m); + py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) @@ -298,6 +336,8 @@ PYBIND11_MODULE(core, m) { .def("get_tensor", [](SelectedRows &self) { return self.mutable_value(); }, py::return_value_policy::reference) + .def("numel", + [](SelectedRows &self) -> int64_t { return self.value().numel(); }) .def("set_height", &SelectedRows::set_height) .def("height", &SelectedRows::height) .def("set_rows", @@ -601,6 +641,7 @@ All parameter, weight, gradient are variables in Paddle. m.def("set_feed_variable", framework::SetFeedVariable); m.def("get_fetch_variable", framework::GetFetchVariable); + m.def("get_variable_tensor", framework::GetVariableTensor); m.def("_is_program_version_supported", IsProgramVersionSupported); diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 02a75236f6c7c7a64f2aa110ca7a7e3d92832fe9..24800e17098759082fd047e51a10fa40ff48b961 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -162,7 +162,7 @@ void PyCPUTensorSetFromArray( paddle::platform::CPUPlace place) { std::vector dims; dims.reserve(array.ndim()); - for (size_t i = 0; i < array.ndim(); ++i) { + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } @@ -182,7 +182,7 @@ inline void PyCPUTensorSetFromArray( paddle::platform::CPUPlace place) { std::vector dims; dims.reserve(array.ndim()); - for (size_t i = 0; i < array.ndim(); ++i) { + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } @@ -200,7 +200,7 @@ void PyCUDATensorSetFromArray( paddle::platform::CUDAPlace place) { std::vector dims; dims.reserve(array.ndim()); - for (size_t i = 0; i < array.ndim(); ++i) { + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } @@ -221,7 +221,7 @@ inline void PyCUDATensorSetFromArray( paddle::platform::CUDAPlace place) { std::vector dims; dims.reserve(array.ndim()); - for (size_t i = 0; i < array.ndim(); ++i) { + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } @@ -240,7 +240,7 @@ void PyCUDAPinnedTensorSetFromArray( const paddle::platform::CUDAPinnedPlace &place) { std::vector dims; dims.reserve(array.ndim()); - for (size_t i = 0; i < array.ndim(); ++i) { + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } @@ -260,7 +260,7 @@ inline void PyCUDAPinnedTensorSetFromArray( const paddle::platform::CUDAPinnedPlace &place) { std::vector dims; dims.reserve(array.ndim()); - for (size_t i = 0; i < array.ndim(); ++i) { + for (decltype(array.ndim()) i = 0; i < array.ndim(); ++i) { dims.push_back(static_cast(array.shape()[i])); } diff --git a/python/paddle/dataset/image.py b/python/paddle/dataset/image.py index 19fc229e6fa84792f58aeeb00be09eb2401b19c7..57547f1867a937d16fb2dfc9b84e1a30759a527e 100644 --- a/python/paddle/dataset/image.py +++ b/python/paddle/dataset/image.py @@ -32,11 +32,28 @@ the image layout as follows. from __future__ import print_function +import six import numpy as np -try: - import cv2 -except ImportError: - cv2 = None +# FIXME(minqiyang): this is an ugly fix for the numpy bug reported here +# https://github.com/numpy/numpy/issues/12497 +if six.PY3: + import subprocess + import sys + import_cv2_proc = subprocess.Popen( + [sys.executable, "-c", "import cv2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + out, err = import_cv2_proc.communicate() + retcode = import_cv2_proc.poll() + if retcode != 0: + cv2 = None + else: + import cv2 +else: + try: + import cv2 + except ImportError: + cv2 = None import os import tarfile import six.moves.cPickle as pickle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 2a53519188e7454b54424cfdd4a713ae37a2326b..e0bb0d1152b258f9b16bb9063acf0ac6012d3432 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -34,6 +34,7 @@ from . import io from . import evaluator from . import initializer from . import layers +from . import imperative from . import contrib from . import nets from . import optimizer @@ -67,6 +68,7 @@ __all__ = framework.__all__ + executor.__all__ + \ 'initializer', 'layers', 'contrib', + 'imperative', 'transpiler', 'nets', 'optimizer', @@ -124,8 +126,9 @@ def __bootstrap__(): 'check_nan_inf', 'benchmark', 'eager_delete_scope', 'use_mkldnn', 'use_ngraph', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size", - 'eager_delete_tensor_gb', 'allocator_strategy', - 'reader_queue_speed_test_mode', 'print_sub_graph_dir' + 'eager_delete_tensor_gb', 'fast_eager_deletion_mode', + 'allocator_strategy', 'reader_queue_speed_test_mode', + 'print_sub_graph_dir', 'pe_profile_fname' ] if 'Darwin' not in sysstr: read_env_flags.append('use_pinned_memory') diff --git a/python/paddle/fluid/average.py b/python/paddle/fluid/average.py index 42cd3b36420ef5a17a9a7d981978ba8869809936..40a734af311e2037c1816dce97db123ebedd2f4f 100644 --- a/python/paddle/fluid/average.py +++ b/python/paddle/fluid/average.py @@ -48,6 +48,7 @@ class WeightedAverage(object): Examples: .. code-block:: python + avg = fluid.average.WeightedAverage() avg.add(value=2.0, weight=1) avg.add(value=4.0, weight=2) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index a40826168dc21e3eec0050c2be7afc1dc74e8e5b..089792059465c60da43d02e8389f4e36900c2292 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -18,6 +18,7 @@ import collections import contextlib import re import six +import sys import numpy as np @@ -49,6 +50,16 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix() ZERO_VAR_SUFFIX = core.kZeroVarSuffix() CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() +_imperative_tracer_ = None + + +def _in_imperative_mode(): + return _imperative_tracer_ is not None + + +def _imperative_tracer(): + return _imperative_tracer_ + class NameScope(object): def __init__(self, name="", parent=None): @@ -345,6 +356,21 @@ class Variable(object): self.op = None self.stop_gradient = stop_gradient self.is_data = is_data + if _in_imperative_mode(): + self._ivar = core.VarBase() + self._ivar.desc = self.desc + + def _numpy(self): + scope = _imperative_tracer().get_scope(self.block.desc) + tensor = core.get_variable_tensor(scope, self.desc.name()) + return np.array(tensor) + + def _backward(self): + scope = _imperative_tracer().get_scope(self.block.desc) + self._ivar._run_backward(scope) + + def _gradient(self): + return np.array(self._ivar._grad()) def __str__(self): return self.to_string(True) @@ -655,6 +681,23 @@ class Operator(object): if self._has_kernel(type): self.desc.infer_var_type(self.block.desc) self.desc.infer_shape(self.block.desc) + if _in_imperative_mode(): + self.iop = core.OpBase() + self.iop.desc = self.desc + self.inputs = [] + if inputs is not None: + for inp in inputs.values(): + if isinstance(inp, Variable): + self.inputs.append(inp) + elif isinstance(inp, list) or isinstance(inp, tuple): + self.inputs.extend(inp[:]) + self.outputs = [] + if outputs is not None: + for out in outputs.values(): + if isinstance(out, Variable): + self.outputs.append(out) + elif isinstance(out, list) or isinstance(out, tuple): + self.outputs.extend(out[:]) def _has_kernel(self, op_type): return op_type not in self.OP_WITHOUT_KERNEL_SET @@ -1041,19 +1084,15 @@ class Block(object): raise ValueError("var %s not in this block" % name) return v - def _var_recursive(self, name): + def _find_var_recursive(self, name): """ Get a Variable by name from this block recursively. Args: name(str): the Variable's name. - Raises: - ValueError: this block and this parent block doesn't - have a Variable with the giving name. - Returns: - Variable: the Variable with the giving name. + Variable: the Variable with the giving name. Or None if not found. """ frontier = list() visited = set() @@ -1079,8 +1118,27 @@ class Block(object): frontier.append(prog.block(cur.forward_block_idx)) visited.add(id(cur)) + return None - raise ValueError("Var {0} is not found recursively".format(name)) + def _var_recursive(self, name): + """ + Get a Variable by name from this block recursively. + + Args: + name(str): the Variable's name. + + Raises: + ValueError: this block and this parent block doesn't + have a Variable with the giving name. + + Returns: + Variable: the Variable with the giving name. + """ + var = self._find_var_recursive(name) + if var: + return var + else: + raise ValueError("Var {0} is not found recursively".format(name)) def all_parameters(self): return list(self.iter_parameters()) @@ -1206,6 +1264,9 @@ class Block(object): """ op_desc = self.desc.append_op() op = Operator(block=self, desc=op_desc, *args, **kwargs) + if _in_imperative_mode(): + _imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs], + [v._ivar for v in op.outputs], self.desc) self.ops.append(op) return op @@ -2210,3 +2271,12 @@ def _get_var(name, program=None): assert isinstance(program, Program) return program.global_block().var(name) + + +@contextlib.contextmanager +def _imperative_guard(tracer): + global _imperative_tracer_ + tmp_trace = _imperative_tracer_ + _imperative_tracer_ = tracer + yield + _imperative_tracer_ = tmp_trace diff --git a/python/paddle/fluid/imperative/__init__.py b/python/paddle/fluid/imperative/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..922308b6b18b335535d41f24d544cde04991b794 --- /dev/null +++ b/python/paddle/fluid/imperative/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from . import base +from .base import * + +from . import layers +from .layers import * + +__all__ = [] +__all__ += layers.__all__ +__all__ += base.__all__ diff --git a/python/paddle/fluid/imperative/base.py b/python/paddle/fluid/imperative/base.py new file mode 100644 index 0000000000000000000000000000000000000000..15d38ddb56c71ef7de67f79cf52cd26070f470cb --- /dev/null +++ b/python/paddle/fluid/imperative/base.py @@ -0,0 +1,56 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import numpy as np + +from paddle.fluid import core +from paddle.fluid import framework + +__all__ = ['enabled', 'guard', 'to_variable'] + + +def enabled(): + return framework._in_imperative_mode() + + +@contextlib.contextmanager +def guard(): + train = framework.Program() + startup = framework.Program() + tracer = core.Tracer(train.current_block().desc) + with framework.program_guard(train, startup): + with framework.unique_name.guard(): + with framework._imperative_guard(tracer): + yield + + +def to_variable(value, block=None): + if isinstance(value, np.ndarray): + if not block: + block = framework.default_main_program().current_block() + py_var = framework.Variable( + block, + type=core.VarDesc.VarType.LOD_TENSOR, + name=None, + shape=value.shape, + dtype=value.dtype) + scope = framework._imperative_tracer().get_scope(block.desc) + var = scope.var(py_var.name) + tensor = var.get_tensor() + tensor.set(value, core.CPUPlace()) + return py_var + elif isinstance(value, framework.Variable): + return value + else: + raise ValueError("Unsupported type %s" % type(value)) diff --git a/python/paddle/fluid/imperative/layers.py b/python/paddle/fluid/imperative/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..1a28f7f4ae35295394b560d79e3dc0cdd5f2beab --- /dev/null +++ b/python/paddle/fluid/imperative/layers.py @@ -0,0 +1,44 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import sys +import numpy as np + +from paddle.fluid import core +from paddle.fluid import framework +from paddle.fluid.imperative import base + +__all__ = ['PyLayer'] + + +class PyLayer(core.Layer): + def __init__(self): + pass + + def __call__(self, inputs): + # TODO(panyx0718): Support declarative mode as well. + assert base.enabled() + if not isinstance(inputs, list) and not isinstance(inputs, tuple): + inputs = [inputs] + + var_inputs = [] + for x in inputs: + py_var = base.to_variable(x) + var_inputs.append(py_var) + outputs = self.forward(var_inputs) + return outputs + + def forward(self, inputs): + return [] diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index dc317de9abbd06f4021e64b87ea88ba6af8809c9..74b4a977db6b69d4d256e1f7b36eb53524269bb1 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -17,10 +17,13 @@ from __future__ import print_function import copy import itertools import six +import sys +import numpy as np from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating from . import unique_name from paddle.fluid.initializer import Constant, Xavier +from paddle.fluid.imperative import base from .param_attr import ParamAttr, WeightNormParamAttr from . import core from six.moves import zip @@ -46,23 +49,21 @@ class LayerHelper(object): def startup_program(self): return default_startup_program() + def to_variable(self, x): + return base.to_variable(x, self.main_program.current_block()) + def append_op(self, *args, **kwargs): return self.main_program.current_block().append_op(*args, **kwargs) def multiple_input(self, input_param_name='input'): inputs = self.kwargs.get(input_param_name, []) - type_error = TypeError( - "Input of {0} layer should be Variable or sequence of Variable". - format(self.layer_type)) - if isinstance(inputs, Variable): - inputs = [inputs] - elif not isinstance(inputs, list) and not isinstance(inputs, tuple): - raise type_error + ret = [] + if isinstance(inputs, list) or isinstance(inputs, tuple): + for inp in inputs: + ret.append(self.to_variable(inp)) else: - for each in inputs: - if not isinstance(each, Variable): - raise type_error - return inputs + ret.append(self.to_variable(inputs)) + return ret def input(self, input_param_name='input'): inputs = self.multiple_input(input_param_name) diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 05138bf94598f649ef7fdbaa94896b6ba0884416..b7e39685691809d04ecddc21d2d04a7a85e478d5 100644 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -717,8 +717,9 @@ class While(object): out_vars = [] for inner_out_name in inner_outputs: - if inner_out_name in parent_block.vars: - out_vars.append(parent_block.var(inner_out_name)) + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_vars.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) @@ -1264,10 +1265,11 @@ class ConditionalBlock(object): if each_name not in input_set ] - out_list = [ - parent_block.var(var_name) for var_name in parent_block.vars - if var_name in intermediate - ] + out_list = [] + for inner_out_name in intermediate: + inner_var = parent_block._find_var_recursive(inner_out_name) + if inner_var: + out_list.append(inner_var) step_scope = parent_block.create_var( type=core.VarDesc.VarType.STEP_SCOPES) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4833212d311e4792dd14709cf3e5843e297e810d..3832cae8c3564447dd2bb8d177c5c4ad9cd9ccd6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -41,6 +41,7 @@ __all__ = [ 'crf_decoding', 'cos_sim', 'cross_entropy', + 'bpr_loss', 'square_error_cost', 'chunk_eval', 'sequence_conv', @@ -172,6 +173,7 @@ __all__ = [ 'merge_selected_rows', 'get_tensor_from_selected_rows', 'lstm', + 'psroi_pool', ] kIgnoreIndex = -100 @@ -1348,6 +1350,44 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): return out +def bpr_loss(input, label, name=None): + """ + Bayesian Personalized Ranking Loss Operator. + + This operator belongs to pairwise ranking loss. Label is the desired item. + The loss at a given point in one session is defined as: + $Y[i] = -\frac{1}{N_{i}-1} * \sum_{0\le j(https://arxiv.org/abs/1511.06939) + + Args: + input (Variable|list): a 2-D tensor with shape [N x D], where N is the + batch size and D is the number of classes. + This input is not probability but logits. + label (Variable|list): the ground truth which is a 2-D tensor. `label` + is a tensor with shape [N x 1]. + name (str|None): A name for this layer(optional). If set None, the + layer will be named automatically. Default: None. + Returns: + A 2-D tensor with shape [N x 1], the bpr loss. + + Examples: + .. code-block:: python + + cost = fluid.layers.bpr_loss(input=predict, label=label) + """ + + helper = LayerHelper('bpr_loss', **locals()) + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='bpr_loss', + inputs={'X': [input], + 'Label': [label]}, + outputs={'Y': [out]}) + return out + + def square_error_cost(input, label): """ **Square error cost layer** @@ -6623,7 +6663,8 @@ def relu(x, name=None): helper = LayerHelper('relu', **locals()) dtype = helper.input_dtype(input_param_name='x') out = helper.create_variable_for_type_inference(dtype) - helper.append_op(type="relu", inputs={"X": x}, outputs={"Out": out}) + helper.append_op( + type="relu", inputs={"X": helper.input('x')}, outputs={"Out": out}) return out @@ -9082,3 +9123,57 @@ def get_tensor_from_selected_rows(x, name=None): outputs={'Out': out}, attrs={}) return out + + +@templatedoc() +def psroi_pool(input, + rois, + output_channels, + spatial_scale, + pooled_height, + pooled_width, + name=None): + """ + ${comment} + + Args: + input (Variable): ${x_comment} + rois (Variable): ROIs (Regions of Interest) to pool over. + output_channels (integer): ${output_channels_comment} + spatial_scale (float): ${spatial_scale_comment} Default: 1.0 + pooled_height (integer): ${pooled_height_comment} Default: 1 + pooled_width (integer): ${pooled_width_comment} Default: 1 + name (str, default None): The name of this layer. + + Returns: + Variable: ${out_comment}. + + Examples: + .. code-block:: python + + pool_out = fluid.layers.psroi_pool(input=x, rois=rois, 490, 1.0, 7, 7) + """ + helper = LayerHelper('psroi_pool', **locals()) + # check attrs + if not isinstance(output_channels, int): + raise TypeError("output_channels must be int type") + if not isinstance(spatial_scale, float): + raise TypeError("spatial_scale must be float type") + if not isinstance(pooled_height, int): + raise TypeError("pooled_height must be int type") + if not isinstance(pooled_width, int): + raise TypeError("pooled_width must be int type") + dtype = helper.input_dtype() + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='psroi_pool', + inputs={'X': input, + 'ROIs': rois}, + outputs={'Out': out}, + attrs={ + 'output_channels': output_channels, + 'spatial_scale': spatial_scale, + 'pooled_height': pooled_height, + 'pooled_width': pooled_width + }) + return out diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 1cda2711f765622b0bda6f4c688f69352bbd2a6f..1c45a10a9ddde743dce9b343e4d18f568bb05e72 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -93,7 +93,7 @@ class TestDistMnist2x2(TestDistRunnerBase): # TODO(typhoonzero): fix distributed adam optimizer # opt = fluid.optimizer.AdamOptimizer( # learning_rate=0.001, beta1=0.9, beta2=0.999) - opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) + opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) # Reader train_reader = paddle.batch( diff --git a/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py b/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c8dc5fbd237d17f2d4e45b06e5806fff5cbf58fe --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bpr_loss_op.py @@ -0,0 +1,52 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest, randomize_probability + + +class TestBprLossOp1(OpTest): + """Test BprLoss with discrete one-hot labels. + """ + + def setUp(self): + self.op_type = "bpr_loss" + batch_size = 40 + class_num = 5 + X = randomize_probability(batch_size, class_num, dtype='float64') + label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64") + bpr_loss_result = [] + for i in range(batch_size): + sum = 0.0 + for j in range(class_num): + if j == label[i][0]: + continue + sum += (-np.log(1.0 + np.exp(X[i][j] - X[i][label[i][0]]))) + bpr_loss_result.append(-sum / (class_num - 1)) + bpr_loss = np.asmatrix([[x] for x in bpr_loss_result], dtype="float64") + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": bpr_loss} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", numeric_grad_delta=0.001) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2130f9049c7ee294444282e59c654551f76603 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py @@ -0,0 +1,61 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3 + + +class TestMKLDNNConcatOp(TestConcatOp): + def setUp(self): + super(TestMKLDNNConcatOp, self).setUp() + self.attrs["use_mkldnn"] = True + self._cpu_only = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNConcatOp2(TestConcatOp2): + def setUp(self): + super(TestMKLDNNConcatOp2, self).setUp() + self.attrs["use_mkldnn"] = True + self._cpu_only = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNConcatOp3(TestConcatOp3): + def setUp(self): + super(TestMKLDNNConcatOp3, self).setUp() + self.attrs["use_mkldnn"] = True + self._cpu_only = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 0a43f536585ad72184e067b585ac8ec326a2e842..cedb3383ed4728306c61d7f987850000506457c7 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -32,7 +32,7 @@ DEFAULT_BATCH_SIZE = 2 class TestDistRunnerBase(object): - def get_model(self, batch_size=DEFAULT_BATCH_SIZE): + def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -56,6 +56,7 @@ class TestDistRunnerBase(object): return t def run_pserver(self, args): + self.lr = args.lr self.get_model(batch_size=args.batch_size) # NOTE: pserver should not call memory optimize t = self.get_transpiler(args.trainer_id, @@ -71,6 +72,7 @@ class TestDistRunnerBase(object): exe.run(pserver_prog) def run_trainer(self, args): + self.lr = args.lr test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ self.get_model(batch_size=args.batch_size) @@ -189,6 +191,7 @@ def runtime_main(test_class): parser.add_argument( '--use_reader_alloc', action='store_true', required=False) parser.add_argument('--batch_size', required=False, type=int, default=2) + parser.add_argument('--lr', required=False, type=float, default=0.001) parser.add_argument( '--batch_merge_repeat', required=False, type=int, default=1) @@ -234,6 +237,7 @@ class TestDistBase(unittest.TestCase): self._dc_asgd = False # must use with async mode self._use_reader_alloc = True self._nccl2_mode = False + self._lr = 0.001 self._setup_config() self._after_setup_config() @@ -284,7 +288,8 @@ class TestDistBase(unittest.TestCase): batch_size=DEFAULT_BATCH_SIZE, batch_merge_repeat=1): - cmd = "%s %s --role trainer" % (self._python_interp, model) + cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model, + self._lr) if batch_size != DEFAULT_BATCH_SIZE: cmd += " --batch_size %d" % batch_size if batch_merge_repeat > 1: @@ -330,13 +335,13 @@ class TestDistBase(unittest.TestCase): ps0_ep, ps1_ep = self._ps_endpoints.split(",") - tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver" + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f" tr0_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 0, ps0_ep, self._trainers) + 0, ps0_ep, self._trainers, self._lr) tr1_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 1, ps1_ep, self._trainers) + 1, ps1_ep, self._trainers, self._lr) if self._sync_mode: tr0_cmd += " --sync_mode" @@ -378,6 +383,18 @@ class TestDistBase(unittest.TestCase): stderr=tr1_pipe, env=env1) + # Wait until trainer process terminate + while True: + stat0 = tr0_proc.poll() + time.sleep(0.1) + if stat0 is not None: + break + while True: + stat1 = tr1_proc.poll() + time.sleep(0.1) + if stat1 is not None: + break + tr0_out, tr0_err = tr0_proc.communicate() tr1_out, tr1_err = tr1_proc.communicate() @@ -390,11 +407,21 @@ class TestDistBase(unittest.TestCase): ps0.terminate() ps1.terminate() + # print server log + with open("/tmp/ps0_err.log", "r") as fn: + sys.stderr.write("ps0 stderr: %s\n" % fn.read()) + with open("/tmp/ps1_err.log", "r") as fn: + sys.stderr.write("ps1 stderr: %s\n" % fn.read()) + # print log - sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out)) - sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err) - sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out)) - sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err) + if stat0 == 0: + sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out)) + with open("/tmp/tr0_err.log", "r") as fn: + sys.stderr.write('trainer 0 stderr: %s\n' % fn.read()) + if stat1 == 0: + sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out)) + with open("/tmp/tr1_err.log", "r") as fn: + sys.stderr.write('trainer 1 stderr: %s\n' % fn.read()) return pickle.loads(tr0_out), pickle.loads(tr1_out) @@ -403,13 +430,13 @@ class TestDistBase(unittest.TestCase): worker_endpoints = self._ps_endpoints.split(",") w0_ep, w1_ep = worker_endpoints - tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2" + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f" tr0_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 0, w0_ep) + 0, w0_ep, self._lr / 2) tr1_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 1, w1_ep) + 1, w1_ep, self._lr / 2) if self._mem_opt: tr0_cmd += " --mem_opt" @@ -474,6 +501,7 @@ class TestDistBase(unittest.TestCase): "PYTHONPATH": os.getenv("PYTHONPATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_rpc_deadline": "5000", # 5sec to fail fast "FLAGS_cudnn_deterministic": "1", "http_proxy": "", "NCCL_P2P_DISABLE": "1" diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 630bed198f4fc382d716373ea872e24b1b45bbf3..49a2ca40e3cb1dd35027345e9c38eb8b6912d2cd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -36,7 +36,7 @@ class TestDistMnistNCCL2(TestDistBase): def test_dist_train(self): import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): - self.check_with_place("dist_mnist.py", delta=1) + self.check_with_place("dist_mnist.py", delta=1e-5) class TestDistMnist2x2Lars(TestDistBase): diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_dynamic_rnn_base.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_dynamic_rnn_base.py new file mode 100644 index 0000000000000000000000000000000000000000..e91cfe0b45ab7e4e56fccf8d49eb381fbbd199d1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_dynamic_rnn_base.py @@ -0,0 +1,86 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0' +os.environ['CPU_NUM'] = '2' + +import six +import unittest + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid + + +def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2): + if use_cuda and not core.is_compiled_with_cuda(): + print('Skip use_cuda=True because Paddle is not compiled with cuda') + return + + word_dict = paddle.dataset.imdb.word_dict() + train_reader = paddle.batch( + paddle.dataset.imdb.train(word_dict), batch_size=batch_size) + + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + cost = network(data, label, len(word_dict)) + optimizer = fluid.optimizer.Adagrad(learning_rate=0.2) + optimizer.minimize(cost) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + feeder = fluid.DataFeeder(feed_list=[data, label], place=place) + reader = feeder.decorate_reader( + train_reader, multi_devices=use_parallel_executor) + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + if use_parallel_executor: + train_exe = fluid.ParallelExecutor( + use_cuda=use_cuda, loss_name=cost.name) + fetch_list = [cost.name] + else: + train_exe = exe + fetch_list = [cost] + + for pass_id in six.moves.xrange(pass_num): + batch_id = 0 + for data in reader(): + train_exe.run(feed=data, + fetch_list=fetch_list if batch_id % 4 == 0 else []) + batch_id += 1 + if batch_id > 16: + break + + +class TestBase(unittest.TestCase): + def setUp(self): + self.net = None + + def test_network(self): + if self.net is None: + return + + for use_cuda in [True, False]: + for use_parallel_executor in [False, True]: + print('network: {}, use_cuda: {}, use_parallel_executor: {}'. + format(self.net.__name__, use_cuda, + use_parallel_executor)) + with fluid.program_guard(fluid.Program(), fluid.Program()): + with fluid.scope_guard(core.Scope()): + train(self.net, use_cuda, use_parallel_executor) diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_gru_net.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_gru_net.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed3d9fdf3bf765f1b9ef8ba1ef2a5795f1874c7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_gru_net.py @@ -0,0 +1,49 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from test_eager_deletion_dynamic_rnn_base import TestBase +import paddle.fluid as fluid + + +def gru_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + emb_lr=400.0): + emb = fluid.layers.embedding( + input=data, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr(learning_rate=emb_lr)) + fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3) + gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False) + gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max') + gru_max_tanh = fluid.layers.tanh(gru_max) + fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh') + prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + return avg_cost + + +class GRUTest(TestBase): + def setUp(self): + self.net = gru_net + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_lstm_net.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_lstm_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8462c06aa56e0469fd06c7dc4b2ed514f7eb51ba --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_lstm_net.py @@ -0,0 +1,50 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test_eager_deletion_dynamic_rnn_base import TestBase +import paddle.fluid as fluid +import unittest + + +def lstm_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2, + emb_lr=30.0): + emb = fluid.layers.embedding( + input=data, + size=[dict_dim, emb_dim], + param_attr=fluid.ParamAttr(learning_rate=emb_lr)) + fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4) + lstm_h, c = fluid.layers.dynamic_lstm( + input=fc0, size=hid_dim * 4, is_reverse=False) + lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max') + lstm_max_tanh = fluid.layers.tanh(lstm_max) + fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh') + prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + return avg_cost + + +class LSTMTest(TestBase): + def setUp(self): + self.net = lstm_net + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_mnist.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec1f0ae753724dac5c4675926ead87a097a7a99 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_mnist.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" + +from test_parallel_executor_mnist import TestMNIST + + +class EagerDeletionTestMNIST(TestMNIST): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..754d5fd40953311a5deb466fa42216f72671a65a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_deletion_transformer.py @@ -0,0 +1,27 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +os.environ['FLAGS_eager_delete_tensor_gb'] = "0.0" + +from test_parallel_executor_transformer import TestTransformer + + +class EagerDeletionTestTransformer(TestTransformer): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b6305155d1ef3dcf6ce590c221664754c5bdc8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative.py @@ -0,0 +1,52 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid import core + + +class MyLayer(fluid.imperative.PyLayer): + def __init__(self): + super(MyLayer, self).__init__() + + def forward(self, inputs): + x = fluid.layers.relu(inputs[0]) + self._x_for_debug = x + return [fluid.layers.elementwise_mul(x, x)] + + +class TestImperative(unittest.TestCase): + def test_layer(self): + with fluid.imperative.guard(): + cl = core.Layer() + cl.forward([]) + l = fluid.imperative.PyLayer() + l.forward([]) + + def test_layer_in_out(self): + with fluid.imperative.guard(): + l = MyLayer() + x = l(np.array([1.0, 2.0, -1.0], dtype=np.float32))[0] + self.assertIsNotNone(x) + sys.stderr.write("%s output: %s\n" % (x, x._numpy())) + x._backward() + sys.stderr.write("grad %s\n" % l._x_for_debug._gradient()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index be51fb06a37a376f6f410336184c95981ded35dc..fb3e4da1efd32ca99f57da8f9955803ddde04f8a 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -511,6 +511,16 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(output) print(str(program)) + def test_psroi_pool(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[245, 30, 30], dtype="float32") + rois = layers.data( + name="rois", shape=[4], dtype="float32", lod_level=1) + output = layers.psroi_pool(x, rois, 5, 0.25, 7, 7) + self.assertIsNotNone(output) + print(str(program)) + def test_roi_align(self): program = Program() with program_guard(program): @@ -846,6 +856,15 @@ class TestBook(unittest.TestCase): out = layers.cross_entropy(x, label, False, 4) self.assertIsNotNone(out) + def test_bpr_loss(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[30, 10], dtype="float32") + label = layers.data(name="label", shape=[30, 1], dtype="int32") + out = layers.bpr_loss(x, label) + self.assertIsNotNone(out) + print(str(program)) + def test_expand(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py index 275e5c49d5c298a95b012582a74f8073b800991e..fa16f082880eb97f54abe8bf75e26321f72b3bd3 100644 --- a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py @@ -22,6 +22,15 @@ from paddle.fluid.framework import Program, program_guard from paddle.fluid.transpiler import memory_optimize +def _get_vars(prog): + assert (isinstance(prog, Program)) + all_vars = set() + for op in prog.global_block().ops: + all_vars.update(op.input_arg_names) + all_vars.update(op.output_arg_names) + return all_vars + + class TestControlFlowGraph(unittest.TestCase): def setUp(self): program = Program() @@ -37,11 +46,11 @@ class TestControlFlowGraph(unittest.TestCase): self.program = program def test_control_flow_graph(self): - print("before optimization") - print(str(self.program)) - result_program = memory_optimize(self.program) - print("after optimization") - print(str(result_program)) + result_program = self.program.clone() + memory_optimize(self.program) + old_vars = _get_vars(self.program) + new_vars = _get_vars(result_program) + self.assertTrue(old_vars != new_vars) class TestMemoryTranspiler2(unittest.TestCase): @@ -58,14 +67,22 @@ class TestMemoryTranspiler2(unittest.TestCase): avg_cost = layers.mean(cost) opt = optimizer.SGD(learning_rate=0.001) opt.minimize(avg_cost) + self.skip_set = set([cost.name, fc.name]) self.program = program def test_inplace_ops(self): - print("before optimization") - print(str(self.program)) - result_program = memory_optimize(self.program) - print("after optimization") - print(str(result_program)) + result_program = self.program.clone() + memory_optimize(self.program) + old_vars = _get_vars(self.program) + new_vars = _get_vars(result_program) + self.assertTrue(old_vars != new_vars) + + def test_skip_opt(self): + result_program = self.program.clone() + memory_optimize(self.program, skip_opt_set=self.skip_set) + old_vars = _get_vars(self.program) + new_vars = _get_vars(result_program) + self.assertTrue(old_vars != new_vars) class TestMemoryTranspiler3(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_psroi_pool_op.py b/python/paddle/fluid/tests/unittests/test_psroi_pool_op.py new file mode 100644 index 0000000000000000000000000000000000000000..abe014a38c6ecfd008b0f1028536bfb49b628fb4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_psroi_pool_op.py @@ -0,0 +1,134 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import math +import numpy as np +import unittest +from op_test import OpTest + + +class TestPSROIPoolOp(OpTest): + def set_data(self): + self.init_test_case() + self.make_rois() + self.calc_psroi_pool() + self.inputs = {'X': self.x, 'ROIs': (self.rois[:, 1:5], self.rois_lod)} + self.attrs = { + 'output_channels': self.output_channels, + 'spatial_scale': self.spatial_scale, + 'pooled_height': self.pooled_height, + 'pooled_width': self.pooled_width + } + self.outputs = {'Out': self.outs} + + def init_test_case(self): + self.batch_size = 3 + self.channels = 3 * 2 * 2 + self.height = 6 + self.width = 4 + + self.x_dim = [self.batch_size, self.channels, self.height, self.width] + + self.spatial_scale = 1.0 / 4.0 + self.output_channels = 3 + self.pooled_height = 2 + self.pooled_width = 2 + + self.x = np.random.random(self.x_dim).astype('float32') + + def make_rois(self): + rois = [] + self.rois_lod = [[]] + for bno in range(self.batch_size): + self.rois_lod[0].append(bno + 1) + for i in range(bno + 1): + x1 = np.random.random_integers( + 0, self.width // self.spatial_scale - self.pooled_width) + y1 = np.random.random_integers( + 0, self.height // self.spatial_scale - self.pooled_height) + + x2 = np.random.random_integers(x1 + self.pooled_width, + self.width // self.spatial_scale) + y2 = np.random.random_integers( + y1 + self.pooled_height, self.height // self.spatial_scale) + roi = [bno, x1, y1, x2, y2] + rois.append(roi) + self.rois_num = len(rois) + self.rois = np.array(rois).astype('float32') + + def calc_psroi_pool(self): + output_shape = (self.rois_num, self.output_channels, self.pooled_height, + self.pooled_width) + out_data = np.zeros(output_shape) + for i in range(self.rois_num): + roi = self.rois[i] + roi_batch_id = int(roi[0]) + roi_start_w = round(roi[1]) * self.spatial_scale + roi_start_h = round(roi[2]) * self.spatial_scale + roi_end_w = (round(roi[3]) + 1.) * self.spatial_scale + roi_end_h = (round(roi[4]) + 1.) * self.spatial_scale + + roi_height = max(roi_end_h - roi_start_h, 0.1) + roi_width = max(roi_end_w - roi_start_w, 0.1) + + bin_size_h = roi_height / float(self.pooled_height) + bin_size_w = roi_width / float(self.pooled_width) + + x_i = self.x[roi_batch_id] + + for c in range(self.output_channels): + for ph in range(self.pooled_height): + for pw in range(self.pooled_width): + hstart = int( + math.floor(float(ph) * bin_size_h + roi_start_h)) + wstart = int( + math.floor(float(pw) * bin_size_w + roi_start_w)) + hend = int( + math.ceil( + float(ph + 1) * bin_size_h + roi_start_h)) + wend = int( + math.ceil( + float(pw + 1) * bin_size_w + roi_start_w)) + hstart = min(max(hstart, 0), self.height) + hend = min(max(hend, 0), self.height) + wstart = min(max(wstart, 0), self.width) + wend = min(max(wend, 0), self.width) + + c_in = (c * self.pooled_height + ph + ) * self.pooled_width + pw + is_empty = (hend <= hstart) or (wend <= wstart) + out_sum = 0. + for ih in range(hstart, hend): + for iw in range(wstart, wend): + out_sum += x_i[c_in, ih, iw] + bin_area = (hend - hstart) * (wend - wstart) + out_data[i, c, ph, pw] = 0. if is_empty else ( + out_sum / float(bin_area)) + self.outs = out_data.astype('float32') + + def setUp(self): + self.op_type = 'psroi_pool' + self.set_data() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_regularizer.py b/python/paddle/fluid/tests/unittests/test_regularizer.py index 20f91cf4485f2e79c20fe90143c8b7deebb9fc49..62994eec7e7f56267a0990d9a5e3b5c62d7d5fe4 100644 --- a/python/paddle/fluid/tests/unittests/test_regularizer.py +++ b/python/paddle/fluid/tests/unittests/test_regularizer.py @@ -15,7 +15,12 @@ from __future__ import print_function import unittest - +from functools import partial +import contextlib +import numpy as np +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid import paddle.fluid.framework as framework import paddle.fluid.optimizer as optimizer import paddle.fluid.regularizer as regularizer @@ -97,5 +102,134 @@ class TestL1DecayRegularizer(unittest.TestCase): self.assertEqual(block.ops[-3].type, 'sign') +def bow_net(data, + label, + dict_dim, + is_sparse=False, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + BOW net + This model is from https://github.com/PaddlePaddle/models: + fluid/PaddleNLP/text_classification/nets.py + """ + emb = fluid.layers.embedding( + input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim]) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bow_tanh = fluid.layers.tanh(bow) + fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + + return avg_cost + + +class TestRegularizer(unittest.TestCase): + def setUp(self): + self.word_dict = paddle.dataset.imdb.word_dict() + reader = paddle.batch( + paddle.dataset.imdb.train(self.word_dict), batch_size=8)() + self.train_data = [next(reader) for _ in range(5)] + + def get_places(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + return places + + @contextlib.contextmanager + def scope_prog_guard(self, main_prog, startup_prog): + scope = fluid.core.Scope() + with fluid.unique_name.guard(): + with fluid.scope_guard(scope): + with fluid.program_guard(main_prog, startup_prog): + yield + + def run_program(self, place, feed_list): + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=feed_list, place=place) + exe.run(fluid.default_startup_program()) + + main_prog = fluid.default_main_program() + param_list = [var.name for var in main_prog.block(0).all_parameters()] + + param_sum = [] + for data in self.train_data: + out = exe.run(main_prog, + feed=feeder.feed(data), + fetch_list=param_list) + p_sum = 0 + for v in out: + p_sum += np.sum(np.abs(v)) + param_sum.append(p_sum) + return param_sum + + def check_l2decay_regularizer(self, place, model): + main_prog = fluid.framework.Program() + startup_prog = fluid.framework.Program() + startup_prog.random_seed = 1 + with self.scope_prog_guard( + main_prog=main_prog, startup_prog=startup_prog): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + avg_cost = model(data, label, len(self.word_dict)) + + optimizer = fluid.optimizer.Adagrad( + learning_rate=0.1, + regularization=fluid.regularizer.L2Decay(1.0)) + optimizer.minimize(avg_cost) + param_sum = self.run_program(place, [data, label]) + return param_sum + + def check_l2decay(self, place, model): + main_prog = fluid.framework.Program() + startup_prog = fluid.framework.Program() + startup_prog.random_seed = 1 + with self.scope_prog_guard( + main_prog=main_prog, startup_prog=startup_prog): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + avg_cost_l2 = model(data, label, len(self.word_dict)) + + param_list = fluid.default_main_program().block(0).all_parameters() + para_sum = [] + for para in param_list: + para_mul = fluid.layers.square(x=para) + para_sum.append(fluid.layers.reduce_sum(input=para_mul)) + avg_cost_l2 += fluid.layers.sums(para_sum) * .5 + + optimizer = fluid.optimizer.Adagrad(learning_rate=0.1) + optimizer.minimize(avg_cost_l2) + param_sum = self.run_program(place, [data, label]) + return param_sum + + def test_l2(self): + for place in self.get_places(): + dense_sparse_p_sum = [] + for sparse in [True, False]: + model = partial(bow_net, is_sparse=sparse) + framework_l2 = self.check_l2decay_regularizer(place, model) + l2 = self.check_l2decay(place, model) + assert len(l2) == len(framework_l2) + for i in range(len(l2)): + assert np.isclose(a=framework_l2[i], b=l2[i], rtol=5e-5) + dense_sparse_p_sum.append(framework_l2) + + assert len(dense_sparse_p_sum[0]) == len(dense_sparse_p_sum[1]) + for i in range(len(dense_sparse_p_sum[0])): + assert np.isclose( + a=dense_sparse_p_sum[0][i], + b=dense_sparse_p_sum[1][i], + rtol=5e-5) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py index 50204b8a77c187aa695da83860960566448d290f..f8847e1570dc47d432777faa15f4004f1a7111a6 100644 --- a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py @@ -63,6 +63,7 @@ class TestSpliteSelectedRows(unittest.TestCase): # expected output selected rows expected_out0_rows = [0, 4] expected_out1_rows = [0, 2] + expected_out2_rows = [] expected_out4_rows = [0] op = Operator( @@ -75,6 +76,7 @@ class TestSpliteSelectedRows(unittest.TestCase): self.assertEqual(outs[0].rows(), expected_out0_rows) self.assertEqual(outs[1].rows(), expected_out1_rows) + self.assertEqual(outs[2].rows(), expected_out2_rows) self.assertEqual(outs[4].rows(), expected_out4_rows) self.assertEqual(outs[0].height(), height_sections[0]) @@ -84,6 +86,9 @@ class TestSpliteSelectedRows(unittest.TestCase): self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1]) self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1]) + self.assertEqual(outs[2].numel(), 0) + self.assertEqual(outs[3].numel(), 0) + def check_grad_with_place(self, place): scope = core.Scope() height = 10 diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index c9f1be934773cc28f026f2b867b9e3a4f7aa8472..95aafec05361a8b66b849268c7a738bb2ee5da86 100755 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -14,6 +14,7 @@ from __future__ import print_function +import six from collections import defaultdict, MutableSet from .. import core from ... import compat as cpt @@ -470,8 +471,21 @@ def memory_optimize(input_program, Returns: None """ + + def to_name_str(var): + if isinstance(var, Variable): + return var.desc.name() + elif isinstance(var, str): + return var + elif isinstance(var, six.string_types): + return str(var) + else: + raise TypeError(str(var) + " should be Variable or str") + if level != 0 and level != 1: raise ValueError("only support opt_level 0 or 1.") + if skip_opt_set is not None and not isinstance(skip_opt_set, set): + raise ValueError("only support skip_opt_set as set.") global PRINT_LOG PRINT_LOG = print_log if skip_grads: @@ -486,6 +500,8 @@ def memory_optimize(input_program, skip_opt_set = grad_set else: skip_opt_set.update(grad_set) + if skip_opt_set is not None: + skip_opt_set = set(map(to_name_str, skip_opt_set)) cfgs = _get_cfgs(input_program) for cfg in cfgs: cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level) diff --git a/python/setup.py.in b/python/setup.py.in index 5aee26b63832889272cde09c553b4615efb8872a..0eb69cdb5c7d140527dba7a648728750bfb404f7 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -101,6 +101,7 @@ packages=['paddle', 'paddle.dataset', 'paddle.reader', 'paddle.fluid', + 'paddle.fluid.imperative', 'paddle.fluid.proto', 'paddle.fluid.proto.profiler', 'paddle.fluid.layers', diff --git a/tools/print_signatures.py b/tools/print_signatures.py index 5c5266f904f5dcf74dd1d4ee7e98081f74a79907..7e61dde0a446cf5bfe656105ffd2472f03576f05 100644 --- a/tools/print_signatures.py +++ b/tools/print_signatures.py @@ -27,6 +27,8 @@ import pydoc member_dict = collections.OrderedDict() +experimental_namespace = {"paddle.fluid.imperative"} + def visit_member(parent_name, member): cur_name = ".".join([parent_name, member.__name__]) @@ -51,6 +53,8 @@ def visit_member(parent_name, member): def visit_all_module(mod): + if (mod.__name__ in experimental_namespace): + return for member_name in ( name for name in (mod.__all__ if hasattr(mod, "__all__") else dir(mod))