diff --git a/CMakeLists.txt b/CMakeLists.txt index efa68c9ba243af3c7cdca52b915cc14d307ae89f..1594e798a2ba3f735a28a43ef933d80b3b3f8964 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,7 +54,7 @@ option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON) option(WITH_DOUBLE "Compile PaddlePaddle with double precision" OFF) option(WITH_RDMA "Compile PaddlePaddle with RDMA support" OFF) option(WITH_TIMER "Compile PaddlePaddle with stats timer" OFF) -option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler" OFF) +option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF) option(WITH_DOC "Compile PaddlePaddle with documentation" OFF) option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF) option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) @@ -254,6 +254,12 @@ elseif() set(WITH_ANAKIN OFF CACHE STRING "Anakin is used in MKL only now." FORCE) endif() +if (WITH_PROFILER) + find_package(Gperftools REQUIRED) + include_directories(${GPERFTOOLS_INCLUDE_DIR}) + add_definitions(-DWITH_GPERFTOOLS) +endif() + include(generic) # simplify cmake module include(package) # set paddle packages include(ccache) # set ccache for compilation diff --git a/cmake/FindGperftools.cmake b/cmake/FindGperftools.cmake new file mode 100644 index 0000000000000000000000000000000000000000..928f573a4fb82391859e334d50e6c8ed0e26aae2 --- /dev/null +++ b/cmake/FindGperftools.cmake @@ -0,0 +1,63 @@ +# Tries to find Gperftools. +# +# Usage of this module as follows: +# +# find_package(Gperftools) +# +# Variables used by this module, they can change the default behaviour and need +# to be set before calling find_package: +# +# Gperftools_ROOT_DIR Set this variable to the root installation of +# Gperftools if the module has problems finding +# the proper installation path. 
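As a usage illustration only (not part of the module), a consumer could point this module at a non-default gperftools install via `Gperftools_ROOT_DIR`. The prefix `/opt/gperftools`, the `my_tool` target, and the module living under `cmake/` are assumptions made for the sketch:

```cmake
# Hypothetical consumer snippet; paths and target names are placeholders.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(Gperftools_ROOT_DIR /opt/gperftools)
find_package(Gperftools REQUIRED)

add_executable(my_tool main.cc)
# Use either the imported target defined at the bottom of this module ...
target_link_libraries(my_tool gperftools::profiler)
# ... or the plain variables it sets:
# target_link_libraries(my_tool ${GPERFTOOLS_LIBRARIES})
# target_include_directories(my_tool PRIVATE ${GPERFTOOLS_INCLUDE_DIR})
```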
+# +# Variables defined by this module: +# +# GPERFTOOLS_FOUND System has Gperftools libs/headers +# GPERFTOOLS_LIBRARIES The Gperftools libraries (tcmalloc & profiler) +# GPERFTOOLS_INCLUDE_DIR The location of Gperftools headers + +find_library(GPERFTOOLS_TCMALLOC + NAMES tcmalloc + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_library(GPERFTOOLS_PROFILER + NAMES profiler + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_library(GPERFTOOLS_TCMALLOC_AND_PROFILER + NAMES tcmalloc_and_profiler + HINTS ${Gperftools_ROOT_DIR}/lib) + +find_path(GPERFTOOLS_INCLUDE_DIR + NAMES gperftools/heap-profiler.h + HINTS ${Gperftools_ROOT_DIR}/include) + +set(GPERFTOOLS_LIBRARIES ${GPERFTOOLS_TCMALLOC_AND_PROFILER}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + Gperftools + DEFAULT_MSG + GPERFTOOLS_LIBRARIES + GPERFTOOLS_INCLUDE_DIR) + +mark_as_advanced( + Gperftools_ROOT_DIR + GPERFTOOLS_TCMALLOC + GPERFTOOLS_PROFILER + GPERFTOOLS_TCMALLOC_AND_PROFILER + GPERFTOOLS_LIBRARIES + GPERFTOOLS_INCLUDE_DIR) + +# create IMPORTED targets +if (Gperftools_FOUND AND NOT TARGET gperftools::tcmalloc) + add_library(gperftools::tcmalloc UNKNOWN IMPORTED) + set_target_properties(gperftools::tcmalloc PROPERTIES + IMPORTED_LOCATION ${GPERFTOOLS_TCMALLOC} + INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}") + add_library(gperftools::profiler UNKNOWN IMPORTED) + set_target_properties(gperftools::profiler PROPERTIES + IMPORTED_LOCATION ${GPERFTOOLS_PROFILER} + INTERFACE_INCLUDE_DIRECTORIES "${GPERFTOOLS_INCLUDE_DIR}") +endif() diff --git a/cmake/configure.cmake b/cmake/configure.cmake index 4e17ddee73958106d5e2c8c8ea5661acc758518a..51f7a61631d7102b60646abe1c6dd7775692f157 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -86,6 +86,7 @@ endif(NOT WITH_GOLANG) if(WITH_GPU) add_definitions(-DPADDLE_WITH_CUDA) + add_definitions(-DEIGEN_USE_GPU) FIND_PACKAGE(CUDA REQUIRED) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 312fbaa0b3d83c37debe78be82503103eabc0bfa..a8b9dcfcf5eec39af0f59c03b1ed9bd4b71ee7bf 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -110,6 +110,14 @@ function(find_fluid_modules TARGET_NAME) endif() endfunction(find_fluid_modules) + +function(common_link TARGET_NAME) + if (WITH_PROFILER) + target_link_libraries(${TARGET_NAME} gperftools::profiler) + endif() +endfunction() + + # find all third_party modules is used for paddle static library # for reduce the dependency when building the inference libs. 
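The new `common_link` helper above is what threads the profiler into individual targets: when `WITH_PROFILER` is ON, the top-level CMakeLists.txt runs `find_package(Gperftools REQUIRED)` and defines `WITH_GPERFTOOLS`, and each build helper in this file (`cc_library`, `cc_binary`, `cc_test`, plus the nv_/hip_/go_ variants in the hunks below) calls `common_link` after wiring its DEPS. A rough sketch with hypothetical target and source names:

```cmake
# Declaring a target through the generic.cmake helpers is enough; each helper
# now ends by calling common_link(). Names below are placeholders.
cc_library(my_kernel SRCS my_kernel.cc DEPS glog gflags)

# Roughly equivalent to:
#   add_library(my_kernel my_kernel.cc)
#   target_link_libraries(my_kernel glog gflags)
#   common_link(my_kernel)   # links gperftools::profiler only when WITH_PROFILER=ON
```

Since `-DWITH_GPERFTOOLS` is added globally, C++ code can keep the actual `ProfilerStart`/`ProfilerFlush` calls behind `#ifdef WITH_GPERFTOOLS`, which is how the `ParallelExecutor` changes later in this patch guard the gperftools profiler.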
set_property(GLOBAL PROPERTY FLUID_THIRD_PARTY) @@ -274,6 +282,7 @@ function(cc_library TARGET_NAME) endif() target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) + common_link(${TARGET_NAME}) endif() # cpplint code style @@ -340,6 +349,7 @@ function(cc_binary TARGET_NAME) if(cc_binary_DEPS) target_link_libraries(${TARGET_NAME} ${cc_binary_DEPS}) add_dependencies(${TARGET_NAME} ${cc_binary_DEPS}) + common_link(${TARGET_NAME}) endif() endfunction(cc_binary) @@ -362,6 +372,7 @@ function(cc_test TARGET_NAME) target_link_libraries(${TARGET_NAME} ${win32_deps}) endif(WIN32) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + common_link(${TARGET_NAME}) add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} ${cc_test_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) @@ -420,6 +431,7 @@ function(nv_binary TARGET_NAME) if(nv_binary_DEPS) target_link_libraries(${TARGET_NAME} ${nv_binary_DEPS}) add_dependencies(${TARGET_NAME} ${nv_binary_DEPS}) + common_link(${TARGET_NAME}) endif() endif() endfunction(nv_binary) @@ -433,6 +445,7 @@ function(nv_test TARGET_NAME) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) if (nv_test_SERIAL) set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) @@ -499,6 +512,7 @@ function(hip_binary TARGET_NAME) if(hip_binary_DEPS) target_link_libraries(${TARGET_NAME} ${hip_binary_DEPS}) add_dependencies(${TARGET_NAME} ${hip_binary_DEPS}) + common_link(${TARGET_NAME}) endif() endif() endfunction(hip_binary) @@ -518,6 +532,7 @@ function(hip_test TARGET_NAME) set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP) target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags) add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main memory gtest gflags) + common_link(${TARGET_NAME}) add_test(${TARGET_NAME} ${TARGET_NAME}) endif() endfunction(hip_test) @@ -560,6 +575,7 @@ function(go_library TARGET_NAME) endif() if(go_library_DEPS) add_dependencies(${TARGET_NAME} ${go_library_DEPS}) + common_link(${TARGET_NAME}) endif(go_library_DEPS) # The "source file" of the library is `${dummyfile}` which never diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index e4c471d86b7bff1bfb3b697ab24219144b4667f5..ce429fefa77b81dff9bf997ba092e92d97cb0dc0 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -129,11 +129,13 @@ cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) -if(NOT WIN32) -cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) -cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog - shape_inference data_transform lod_tensor profiler) -endif(NOT WIN32) +if(WITH_NGRAPH) + if(NOT WIN32) + cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) + cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog + shape_inference data_transform lod_tensor 
profiler ngraph) + endif(NOT WIN32) +endif(WITH_NGRAPH) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) @@ -169,11 +171,15 @@ if(WITH_DISTRIBUTE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() - if(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper) - else(NOT WIN32) + if(WITH_NGRAPH) + if(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph ngraph_operator variable_helper) + else(NOT WIN32) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) + endif(NOT WIN32) + else(WITH_NGRAPH) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) - endif(NOT WIN32) + endif(WITH_NGRAPH) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 73cec21e20f2fd26e144872f1f7b5bb7065adb74..e97cf44c75cfdc2e7df22aa870916866b18b3b5a 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -17,7 +17,6 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" -#include "paddle/fluid/framework/ngraph_operator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/transfer_scope_cache.h" @@ -26,6 +25,10 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_NGRAPH +#include "paddle/fluid/framework/ngraph_operator.h" +#endif + DECLARE_bool(benchmark); DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); DEFINE_bool(use_ngraph, false, "Use NGRAPH to run"); @@ -88,11 +91,11 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op, static void EnableFusedOp(ExecutorPrepareContext* ctx) { #ifdef PADDLE_WITH_NGRAPH VLOG(3) << "use_ngraph=True"; - auto intervals = FusedOperator::FusedOpIntervals(&ctx->ops_); + auto intervals = NgraphOperator::NgraphOpIntervals(&ctx->ops_); for (auto& interval : intervals) { - auto* fused_op = new FusedOperator(ctx->prog_, ctx->block_id_, - interval.at(0), interval.at(1)); - *interval[0] = std::unique_ptr(fused_op); + auto* ng_op = new NgraphOperator(ctx->prog_, ctx->block_id_, interval.at(0), + interval.at(1)); + *interval[0] = std::unique_ptr(ng_op); } for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) { ctx->ops_.erase(it->at(0) + 1, it->at(1)); diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index e22c29037718a60ff7f24404d7749600e2edb80b..a5acfd70449e92663cb66ef90a141c087ff6ec88 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #include #include #include @@ -27,14 +26,15 @@ namespace paddle { namespace framework { static std::shared_ptr GetNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, const VariableNameMap& var_map, std::shared_ptr< std::unordered_map>> ngb_node_map) { - auto& var_names = var_map.at(prm); + auto& var_names = var_map.at(name); PADDLE_ENFORCE_EQ(var_names.size(), 1, - "op %s prm %s expects one associated var", op->Type(), prm); + "op %s name %s expects one associated var", op->Type(), + name); if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) { return (*ngb_node_map)[var_names[0]]; } else { @@ -43,42 +43,42 @@ static std::shared_ptr GetNode( } static std::shared_ptr GetInputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr< std::unordered_map>> ngb_node_map) { - return GetNode(op, prm, op->Inputs(), ngb_node_map); + return GetNode(op, name, op->Inputs(), ngb_node_map); } static std::shared_ptr GetOutputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr< std::unordered_map>> ngb_node_map) { - return GetNode(op, prm, op->Outputs(), ngb_node_map); + return GetNode(op, name, op->Outputs(), ngb_node_map); } static void SetOutputNode( - const std::shared_ptr& op, const std::string prm, + const std::shared_ptr& op, const std::string name, std::shared_ptr node, std::shared_ptr< std::unordered_map>> ngb_node_map) { - auto& var_names = op->Outputs().at(prm); + auto& var_names = op->Outputs().at(name); if (var_names.size() == 1) { (*ngb_node_map)[var_names[0]] = node; } else if (var_names.size() == 0) { (*ngb_node_map)[""] = node; } else { - PADDLE_THROW("prm %s has more than 1 var_names.", prm); + PADDLE_THROW("name %s has more than 1 var_names.", name); } } static bool HasOutput(const std::shared_ptr& op, - const std::string prm) { 
+ const std::string name) { auto& outputs = op->Outputs(); - if (outputs.find(prm) == outputs.end()) return false; - return outputs.at(prm).size() > 0; + if (outputs.find(name) == outputs.end()) return false; + return outputs.at(name).size() > 0; } template @@ -118,4 +118,3 @@ void NgraphBridge::BuildNgNode(const std::shared_ptr& op) { } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_bridge.h b/paddle/fluid/framework/ngraph_bridge.h index 9ed6b9510942136a61faa5755fd8fa74286939a8..5ad7b8daeb6a782515e50fc87ca7188b46308390 100644 --- a/paddle/fluid/framework/ngraph_bridge.h +++ b/paddle/fluid/framework/ngraph_bridge.h @@ -14,8 +14,6 @@ limitations under the License. */ #pragma once -#ifdef PADDLE_WITH_NGRAPH - #include #include #include @@ -53,4 +51,3 @@ class NgraphBridge { } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index 3fea753f0659019395c9b214e52a7912058c501c..253de4c61160e52202a0192215a93284f27e5896 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #include #include @@ -58,16 +57,16 @@ typedef enum { /* nGraph support state on ops */ } op_state; // perform graph build through bridge and execute computation -class NgraphOperator { +class NgraphEngine { public: - explicit NgraphOperator(const Scope& scope, const platform::Place& place, - const std::vector>& ops, - const std::unordered_map< - std::string, ngraph::element::Type>& var_type_map, - const std::unordered_set& persist, - const std::unordered_set& fetches, - const std::unordered_set& post_op_inputs, - op_state ng_op_state) + explicit NgraphEngine(const Scope& scope, const platform::Place& place, + const std::vector>& ops, + const std::unordered_map< + std::string, ngraph::element::Type>& var_type_map, + const std::unordered_set& persist, + const std::unordered_set& fetches, + const std::unordered_set& post_op_inputs, + op_state ng_op_state) : scope_(scope), place_(place), fused_ops_(ops), @@ -132,7 +131,7 @@ class NgraphOperator { }; std::vector>::iterator>> -FusedOperator::FusedOpIntervals( +NgraphOperator::NgraphOpIntervals( std::vector>* ops) { std::vector>::iterator>> intervals; @@ -185,7 +184,7 @@ FusedOperator::FusedOpIntervals( return intervals; } -FusedOperator::FusedOperator( +NgraphOperator::NgraphOperator( const ProgramDesc& prog, size_t block_id, std::vector>::iterator start, std::vector>::iterator end, @@ -215,7 +214,7 @@ FusedOperator::FusedOperator( Process(); } -void FusedOperator::Process() { +void NgraphOperator::Process() { auto& bdesc = pdesc_.Block(block_); for (auto& var : bdesc.AllVars()) { if (!(var->GetType() == proto::VarType::SELECTED_ROWS || @@ -251,8 +250,8 @@ void FusedOperator::Process() { } } -void FusedOperator::RunImpl(const Scope& scope, - const platform::Place& place) const { +void NgraphOperator::RunImpl(const Scope& scope, + const platform::Place& place) const { op_state ng_op_state = PARTIAL_TEST; auto& bdesc = pdesc_.Block(block_); for (auto* op : bdesc.AllOps()) { @@ -266,19 +265,19 @@ void FusedOperator::RunImpl(const Scope& scope, ng_op_state = ng_op_state == PARTIAL_TEST ? 
FULL_TEST : FULL_TRAIN; } - NgraphOperator ngraph_op(scope, place, fused_ops_, var_type_map_, - persistables_, fetches_, post_op_inputs_, - ng_op_state); - ngraph_op.Run(scope, place); + NgraphEngine ngraph_engine(scope, place, fused_ops_, var_type_map_, + persistables_, fetches_, post_op_inputs_, + ng_op_state); + ngraph_engine.Run(scope, place); } std::unordered_map> - NgraphOperator::func_cache_ = {}; + NgraphEngine::func_cache_ = {}; -std::shared_ptr NgraphOperator::backend_ = +std::shared_ptr NgraphEngine::backend_ = ngraph::runtime::Backend::create("CPU"); -void NgraphOperator::GetNgInputShape(std::shared_ptr op) { +void NgraphEngine::GetNgInputShape(std::shared_ptr op) { op->RuntimeInferShape(scope_, place_); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { @@ -301,7 +300,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr op) { } } -void NgraphOperator::BuildNgNodes() { +void NgraphEngine::BuildNgNodes() { for (auto& var_name : var_out_) { if (var_node_map_->find(var_name) == var_node_map_->end()) { auto* var = scope_.FindVar(var_name); @@ -323,7 +322,7 @@ void NgraphOperator::BuildNgNodes() { } } -void NgraphOperator::BuildNgIO() { +void NgraphEngine::BuildNgIO() { std::unordered_set inputs; std::unordered_set outputs; @@ -395,7 +394,7 @@ void NgraphOperator::BuildNgIO() { } } -void NgraphOperator::BuildNgFunction() { +void NgraphEngine::BuildNgFunction() { BuildNgNodes(); ngraph_function_ = nullptr; ngraph::NodeVector func_outputs; @@ -416,7 +415,7 @@ void NgraphOperator::BuildNgFunction() { std::make_shared(func_outputs, func_inputs); } -std::shared_ptr NgraphOperator::GetCacheKey() { +std::shared_ptr NgraphEngine::GetCacheKey() { auto cache_key = std::make_shared(""); *cache_key += std::to_string(fused_ops_.size()); for (auto& op : fused_ops_) { @@ -444,7 +443,7 @@ std::shared_ptr NgraphOperator::GetCacheKey() { return cache_key; } -void NgraphOperator::GetNgFunction() { +void NgraphEngine::GetNgFunction() { bool cache_on = true; if (cache_on) { std::string cache_key_val = *GetCacheKey(); @@ -459,8 +458,7 @@ void NgraphOperator::GetNgFunction() { } } -void NgraphOperator::Run(const Scope& scope, - const platform::Place& place) const { +void NgraphEngine::Run(const Scope& scope, const platform::Place& place) const { std::vector> t_in; std::vector> t_out; @@ -545,7 +543,6 @@ void NgraphOperator::Run(const Scope& scope, } backend_->call(ngraph_function_, t_out, t_in); -} // NgraphOperator::RunImpl +} // NgraphEngine::RunImpl } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/ngraph_operator.h b/paddle/fluid/framework/ngraph_operator.h index 3ca023e11111c5b447b2cabbfb8bb29877297f65..ede80f44bea208b66acc3b3f4bc0f4adee4fb860 100644 --- a/paddle/fluid/framework/ngraph_operator.h +++ b/paddle/fluid/framework/ngraph_operator.h @@ -14,8 +14,6 @@ limitations under the License. */ #pragma once -#ifdef PADDLE_WITH_NGRAPH - #include #include #include @@ -34,14 +32,14 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -class FusedOperator : public OperatorBase { +class NgraphOperator : public OperatorBase { public: static std::vector< std::vector>::iterator>> - FusedOpIntervals( + NgraphOpIntervals( std::vector>* ops); - explicit FusedOperator( + explicit NgraphOperator( const ProgramDesc& prog, size_t block_id, std::vector>::iterator start, std::vector>::iterator end, @@ -64,4 +62,3 @@ class FusedOperator : public OperatorBase { }; } // namespace framework } // namespace paddle -#endif diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 36673e48c2047bca54f604b082dfec123f1e2c82..6d39bb3c524b4725dfebd6ef07594b0b45c65463 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -319,7 +319,7 @@ struct OpKernelRegistrarFunctorEx &places) - : places_(places) {} + : places_(places) { + if (!FLAGS_pe_profile_fname.empty()) { + std::call_once(gProfileOnce, [] { +#ifdef WITH_GPERFTOOLS + ProfilerStart(FLAGS_pe_profile_fname.c_str()); + gProfileStarted = true; +#else + LOG(WARNING) << "Paddle is not compiled with gperftools. " + "FLAGS_pe_profile_fname will be ignored"; +#endif + }); + } + } ~ParallelExecutorPrivate() { if (own_local_scope_) { @@ -270,6 +293,12 @@ void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { +#ifdef WITH_GPERFTOOLS + if (gProfileStarted) { + ProfilerFlush(); + } +#endif + platform::RecordBlock b(0); #ifdef PADDLE_WITH_CUDA if (!gcs_.empty()) { diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index b8a045c18fab54581b4d2b902be373f55ad09e8a..c6e923c00484f01f17550ae2926dabcadc0c3ac6 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -44,9 +44,10 @@ void IrGraphBuildPass::RunImpl(Argument *argument) { argument->SetMainProgram(program.release()); } else if (argument->model_program_path_valid() && argument->model_params_path_valid()) { - auto program = - LoadModel(argument->model_program_path(), argument->model_params_path(), - argument->scope_ptr(), place, argument->model_from_memory()); + auto program = LoadModel( + argument->model_program_path(), argument->model_params_path(), + argument->scope_ptr(), place, + argument->model_from_memory_valid() && argument->model_from_memory()); argument->SetMainProgram(program.release()); } else { PADDLE_THROW( diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index a07626a10315a6206f8c1ebc9a19df90663a88ee..8a4bc04b67879918c6ac8d1b40dae68a107034d4 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -1,4 +1,4 @@ -set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor) +set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor benchmark) if(WITH_GPU AND TENSORRT_FOUND) set(INFERENCE_EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor) diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index d572ea0177c1e398229a02718ca31cc78a7059ef..8209a049f4614fe31c22c4e83c1968411b749b49 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ 
b/paddle/fluid/inference/tests/api/tester_helper.h @@ -30,8 +30,10 @@ #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/tests/api/config_printer.h" #include "paddle/fluid/inference/tests/test_helper.h" +#include "paddle/fluid/inference/utils/benchmark.h" #include "paddle/fluid/platform/profiler.h" +DEFINE_string(model_name, "", "model name"); DEFINE_string(infer_model, "", "model path"); DEFINE_string(infer_data, "", "data file"); DEFINE_int32(batch_size, 1, "batch size."); @@ -40,6 +42,8 @@ DEFINE_bool(test_all_data, false, "Test the all dataset in data file."); DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads."); DEFINE_bool(use_analysis, true, "Running the inference program in analysis mode."); +DEFINE_bool(record_benchmark, false, + "Record benchmark after profiling the model"); DECLARE_bool(profile); DECLARE_int32(paddle_num_threads); @@ -192,8 +196,16 @@ void TestOneThreadPrediction( predictor->Run(inputs[j], outputs, batch_size); } } - PrintTime(batch_size, num_times, 1, 0, run_timer.toc() / num_times, - inputs.size()); + + double latency = run_timer.toc() / num_times; + PrintTime(batch_size, num_times, 1, 0, latency, inputs.size()); + if (FLAGS_record_benchmark) { + Benchmark benchmark; + benchmark.SetName(FLAGS_model_name); + benchmark.SetBatchSize(batch_size); + benchmark.SetLatency(latency); + benchmark.PersistToFile("benchmark_record.txt"); + } } } diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc index ef612ce6148329c33f194842945bb5438afcf645..9eb3fb5da1065f14d9ad1c3520f6415fbadfdca1 100644 --- a/paddle/fluid/inference/tests/api/trt_models_tester.cc +++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc @@ -135,6 +135,9 @@ TEST(TensorRT_resnext50, compare) { TEST(TensorRT_resnext50, profile) { std::string model_dir = FLAGS_infer_model + "/resnext50"; + // Set FLAGS_record_benchmark to true to record benchmark to file. 
+ // FLAGS_record_benchmark=true; + FLAGS_model_name = "resnext50"; profile(model_dir, /* use_analysis */ true, FLAGS_use_tensorrt); } diff --git a/paddle/fluid/inference/utils/benchmark.cc b/paddle/fluid/inference/utils/benchmark.cc index d03aa11b75ee58524746212e43a5796773f47932..0bd526bcac2d9ceda95730dc3c5210aed8ccfb5c 100644 --- a/paddle/fluid/inference/utils/benchmark.cc +++ b/paddle/fluid/inference/utils/benchmark.cc @@ -30,7 +30,7 @@ std::string Benchmark::SerializeToString() const { ss << '\n'; ss << name_ << "\t"; - ss << batch_size_ << "\t"; + ss << batch_size_ << "\t\t"; ss << num_threads_ << "\t"; ss << latency_ << "\t"; ss << 1000.0 / latency_; diff --git a/paddle/fluid/inference/utils/visualizer.cc b/paddle/fluid/inference/utils/visualizer.cc index 040b6476fb4febc5ca1912c8db72dc63c3bced08..7c0dd64dea88e51b24c4bc04818d633ee0d2f722 100644 --- a/paddle/fluid/inference/utils/visualizer.cc +++ b/paddle/fluid/inference/utils/visualizer.cc @@ -26,9 +26,6 @@ DEFINE_string(model_dir, "", "model directory"); DEFINE_string(model_program_path, "", "model program path"); DEFINE_string(model_params_path, "", "model params path"); -USE_PASS(graph_viz_pass); -USE_PASS(graph_to_program_pass); - using paddle::inference::analysis::Argument; namespace paddle { @@ -40,7 +37,6 @@ void Visualizer::SetArgument(Argument *argument) { argument_ = argument; } bool Visualizer::Run() { paddle::framework::InitDevices(false); paddle::inference::analysis::Analyzer().Run(argument_); - return true; } @@ -77,7 +73,7 @@ int main(int argc, char *argv[]) { // Only 1 pass, default filename is 0_ir_origin.dot // For more details, looking for paddle::inference::analysis::IRPassManager - argument.SetIrAnalysisPasses({"graph_viz_pass"}); + argument.SetIrAnalysisPasses({"infer_clean_graph_pass", "graph_viz_pass"}); std::unique_ptr scope{ new paddle::framework::Scope()}; @@ -90,3 +86,7 @@ int main(int argc, char *argv[]) { return 0; } + +USE_PASS(infer_clean_graph_pass); +USE_PASS(graph_viz_pass); +USE_PASS(graph_to_program_pass); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 87d549678a0e6c183aac89539cf1f6331729de2c..c7df3ea58a91579e35ff0d486516271a6daf054f 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -301,23 +301,22 @@ template struct GeluFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out) const { - auto temp = - ((x * static_cast(M_SQRT1_2)).erf()).template cast().eval(); + auto temp = (x * static_cast(M_SQRT1_2)).erf(); out.device(d) = x * static_cast(0.5) * (static_cast(1) + temp); } }; template struct GeluGradFunctor : BaseActivationFunctor { - bool Inplace() const { return IsInplace("gelu"); } template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto temp = (static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * - ((-static_cast(0.5) * x.square()).exp())) - .template cast() - .eval(); - dx.device(d) = dout * (out / x + temp); + auto first = static_cast(0.5) * + (static_cast(1) + ((x * static_cast(M_SQRT1_2)).erf())); + + auto second = static_cast(0.5 * M_2_SQRTPI * M_SQRT1_2) * x * + (-static_cast(0.5) * x.square()).exp(); + dx.device(d) = dout * (first + second); } }; diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.cu b/paddle/fluid/operators/bilinear_tensor_product_op.cu index 9426ffbe174c7daf9f24525f5f7ca12d986042f4..c2b4f69e6854522b91dfd9fb5f738c0e5ffc77b1 100644 --- a/paddle/fluid/operators/bilinear_tensor_product_op.cu +++ 
b/paddle/fluid/operators/bilinear_tensor_product_op.cu @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/fluid/operators/bilinear_tensor_product_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/concat_mkldnn_op.cc b/paddle/fluid/operators/concat_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ad674056f0d753d79408a11eff1aca47a84998a --- /dev/null +++ b/paddle/fluid/operators/concat_mkldnn_op.cc @@ -0,0 +1,152 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/operators/concat_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using framework::DataLayout; +using framework::Tensor; +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::concat; +using mkldnn::stream; +using platform::to_void_cast; + +static void EnforceLayouts(const std::vector inputs) { + for (auto* input : inputs) { + const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN; + const bool is_format_defined = + input->format() != memory::format::format_undef; + PADDLE_ENFORCE(is_layout_correct && is_format_defined, + "Wrong layout/format set for Input tensor"); + } +} + +static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, + const mkldnn::engine& engine) { + constexpr auto data_type = mkldnn::memory::f32; + const auto dims = paddle::framework::vectorize2int(input.dims()); + const auto format = input.format(); + auto description = memory::desc(dims, data_type, format); + auto mem_prim_desc = memory::primitive_desc(description, engine); + return mem_prim_desc; +} + +static mkldnn::memory::format GetDstMemFormat( + const concat::primitive_desc& concat_pd) { + return (memory::format)concat_pd.dst_primitive_desc().desc().data.format; +} + +static platform::CPUPlace GetCpuPlace( + const paddle::framework::ExecutionContext& ctx) { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(paddle::platform::is_cpu_place(place), + "It must use CPUPlace."); + return boost::get(place); +} + +static const mkldnn::engine& GetMKLDNNEngine( + const paddle::framework::ExecutionContext& ctx) { + auto& dev_ctx = ctx.template device_context(); + return dev_ctx.GetEngine(); +} + +template +class ConcatPrimitiveFactory { + public: + concat::primitive_desc CreateConcatPrimDescriptor( + const std::vector multi_input, Tensor* output, + int concat_axis, const mkldnn::engine& mkldnn_engine) { + CreateSourcesDescriptors(multi_input, mkldnn_engine); + auto dst_desc = CreateDstMemDescriptor(output); + return concat::primitive_desc(dst_desc, concat_axis, srcs_pd); + } + + concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd, + Tensor* output, platform::CPUPlace place) { + CreateSourcePrimitiveAts(); + dst_mem = CreateDstMemory(concat_pd, 
output, place); + return concat(concat_pd, inputs, dst_mem.get()); + } + + private: + memory::desc CreateDstMemDescriptor(Tensor* output) { + auto dst_dims = paddle::framework::vectorize2int(output->dims()); + return memory::desc(dst_dims, platform::MKLDNNGetDataType(), + memory::format::any); + } + + mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd, + Tensor* output, platform::CPUPlace place) { + return memory(concat_pd.dst_primitive_desc(), + output->mutable_data(place)); + } + + void CreateSourcesDescriptors(const std::vector multi_input, + const mkldnn::engine& mkldnn_engine) { + for (size_t i = 0; i < multi_input.size(); i++) { + auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); + srcs_pd.push_back(mem_prim_desc); + srcs.push_back( + memory(mem_prim_desc, to_void_cast(multi_input[i]->data()))); + } + } + + void CreateSourcePrimitiveAts() { + inputs.reserve(srcs.size()); + for (size_t i = 0; i < srcs.size(); i++) { + inputs.push_back(srcs[i]); + } + } + + private: + std::vector srcs_pd; + std::vector srcs; + std::vector inputs; + boost::optional dst_mem; // TODO(mgallus): change to std::optional +}; // upon introduction of C++17 to paddle + +template +class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + auto place = GetCpuPlace(ctx); + const auto& mkldnn_engine = GetMKLDNNEngine(ctx); + + auto multi_input = ctx.MultiInput("X"); + EnforceLayouts(multi_input); + Tensor* output = ctx.Output("Out"); + int64_t concat_axis = static_cast(ctx.Attr("axis")); + + ConcatPrimitiveFactory prim_creator; + auto concat_pd = prim_creator.CreateConcatPrimDescriptor( + multi_input, output, static_cast(concat_axis), mkldnn_engine); + auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place); + stream(stream::kind::eager).submit({concat}).wait(); + + output->set_layout(DataLayout::kMKLDNN); + output->set_format(GetDstMemFormat(concat_pd)); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace, + ops::ConcatMKLDNNOpKernel) diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 57817da71adfd80faad29a48b05ba2f326de6c07..194f9cf5033a3a73afeb8e92ddbdcc7b316fcd35 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/operators/concat_op.h" - #include #include +#ifdef PADDLE_WITH_MKLDNN +#include +#endif + namespace paddle { namespace operators { using framework::Tensor; @@ -59,6 +62,22 @@ class ConcatOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", out_dims); ctx->ShareLoD("X", /*->*/ "Out"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]); + +#ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { @@ -66,6 +85,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Input tensors of concat operator.").AsDuplicable(); AddOutput("Out", "Output tensor of concat operator."); + AddAttr( + "use_mkldnn", + "(bool, default false) Indicates if MKL-DNN kernel will be used") + .SetDefault(false); AddAttr("axis", "The axis along which the input tensors will be concatenated.") .SetDefault(0); diff --git a/paddle/fluid/operators/cos_sim_op.cu b/paddle/fluid/operators/cos_sim_op.cu index 82205e9c75402e368a2d1e161d471e35ff7356ea..3d144ca29d9989ad2cbb438a950860eaac873d07 100644 --- a/paddle/fluid/operators/cos_sim_op.cu +++ b/paddle/fluid/operators/cos_sim_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/cos_sim_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/crop_op.cu b/paddle/fluid/operators/crop_op.cu index b75678217e36aa2297c68a7f8e2a9dfafadaca72..66cb5c452de4b2107693127ce414daf9fb7cd7d8 100644 --- a/paddle/fluid/operators/crop_op.cu +++ b/paddle/fluid/operators/crop_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/crop_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/distributed/brpc_client.cc b/paddle/fluid/operators/distributed/brpc_client.cc index b394c678fb6503eb73a1e11e6feb814251e9e940..350969f74be258ffbfef687b56083a9c6508bc81 100644 --- a/paddle/fluid/operators/distributed/brpc_client.cc +++ b/paddle/fluid/operators/distributed/brpc_client.cc @@ -158,7 +158,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) { for (int i = 0; i < FLAGS_brpc_channel_num; ++i) { std::shared_ptr c(new ChannelContext()); if (c->channel.Init(ep.c_str(), &options) != 0) { - LOG(ERROR) << "Fail to initialize channel"; + LOG(FATAL) << "Fail to initialize channel"; return nullptr; } diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc index 857214aa211aee0251571e46049c66c084b470f1..f14dfcdb238a9580affde96e4d5a0093743eb6c8 100644 --- a/paddle/fluid/operators/distributed/grpc_client.cc +++ b/paddle/fluid/operators/distributed/grpc_client.cc @@ -390,8 +390,7 @@ void GRPCClient::Proceed() { VLOG(3) << c->GetVarHandlePtr()->String() << " process"; c->Process(); } else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) { - // FIXME(gongwb): parse error_details? - LOG(ERROR) << c->GetVarHandlePtr()->String() + LOG(FATAL) << c->GetVarHandlePtr()->String() << " meets grpc error, error_code:" << c->status_.error_code() << " error_message:" << c->status_.error_message() << " error_details:" << c->status_.error_details(); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu index e011f47e086183a4ef3a3373c17acd6c21b6cf7e..d65491267de1ce3495d8b8250cf0cff570dfcc6a 100644 --- a/paddle/fluid/operators/dropout_op.cu +++ b/paddle/fluid/operators/dropout_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include #include #include diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index 2fb7eeb4b9e3119a6eea3e69a2a6002a80f6c0f3..fed12785f47e1b8eea3f053712830901bee3bdc9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index c5a1a7e08d89f3ef205af4c37246f8fa288189f3..1a149298fd33f132a90ff5de3b35dd5894a4ae68 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index a90dcd3ecf0da114110db5946e111a8b3a925e42..5d086a1b29febd8e57507eced7683f414ca34e07 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_max_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index ab77709c28c15a925bd3deac07c43e12b12cb781..cf93e5a97a3f3110aae907c593f58dbab0f9d090 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 4d16bc38e1d8e4cbbe3afbe08f233e14329e0f2e..833c4072826c58277bc23e03b787fafbbaa73d03 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu index 6ee0779f23bc2c734aa1d439abb12f366227e686..9263dbfebfd00451f3e67c3ca9d2081b5b4904bd 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu @@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 8d9bf7c4d81d49d83b5d1cf0369be5c9957242b4..6f17d3292f307b009c640738109d5a4f4ca4caa9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/expand_op.cu b/paddle/fluid/operators/expand_op.cu index 60363bfc86d7d1a79d7b018cee43a41c1247a994..d95c9b61802b5fe7059e1c95a50776db5aa7ad93 100644 --- a/paddle/fluid/operators/expand_op.cu +++ b/paddle/fluid/operators/expand_op.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/expand_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/gru_unit_op.cu b/paddle/fluid/operators/gru_unit_op.cu index fc92b3d4a7a5a933f31b21d18238de386b3afb4d..37689901ecbeeda44f52a2fc7a686f4edf6682bb 100644 --- a/paddle/fluid/operators/gru_unit_op.cu +++ b/paddle/fluid/operators/gru_unit_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/gru_unit_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/hinge_loss_op.cu b/paddle/fluid/operators/hinge_loss_op.cu index 9c0a85bee6e28865225c1848ea5a378f48932ceb..b5ea0a702e0e540c1831ca241a5def19f86c239c 100644 --- a/paddle/fluid/operators/hinge_loss_op.cu +++ b/paddle/fluid/operators/hinge_loss_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/hinge_loss_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/huber_loss_op.cu b/paddle/fluid/operators/huber_loss_op.cu index 659464df9dc0e7c8cd276bd0bbf7072361aa3abf..09c743c4275169ba8c53ccbd428100b2fc4483d6 100644 --- a/paddle/fluid/operators/huber_loss_op.cu +++ b/paddle/fluid/operators/huber_loss_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/huber_loss_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/im2sequence_op.cu b/paddle/fluid/operators/im2sequence_op.cu index e0a5a90c1c3c47ea45b3f83ae969c1861783ff60..1c34640618d58d3b5fe627fa6596260a7b687d05 100644 --- a/paddle/fluid/operators/im2sequence_op.cu +++ b/paddle/fluid/operators/im2sequence_op.cu @@ -11,8 +11,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/im2sequence_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/isfinite_op.cu b/paddle/fluid/operators/isfinite_op.cu index 8d1268b18c6fec03063051f545075209a6fcde27..995969cd42f08c7fa948262e42793106e745b3a7 100644 --- a/paddle/fluid/operators/isfinite_op.cu +++ b/paddle/fluid/operators/isfinite_op.cu @@ -11,8 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/operators/l1_norm_op.cu b/paddle/fluid/operators/l1_norm_op.cu index 1b48571dd7378c1c2a6628662024bc7bcc08d2a6..a5c29bbf5debdd11f6e5b28b3a8b48c2c484517a 100644 --- a/paddle/fluid/operators/l1_norm_op.cu +++ b/paddle/fluid/operators/l1_norm_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/l1_norm_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/log_loss_op.cu b/paddle/fluid/operators/log_loss_op.cu index e8bf7d8159bf8b16bf4397e7765918c060124db3..280913c43a2749ddd5fbd3ae1905f1b823dd525d 100644 --- a/paddle/fluid/operators/log_loss_op.cu +++ b/paddle/fluid/operators/log_loss_op.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/log_loss_op.h" namespace ops = paddle::operators; diff --git a/paddle/fluid/operators/math/context_project.cu b/paddle/fluid/operators/math/context_project.cu index 16205c0e145ef70666d4eca564488d80bde26d2e..f04b2d15349be329ee228fc8903c9b38a5349634 100644 --- a/paddle/fluid/operators/math/context_project.cu +++ b/paddle/fluid/operators/math/context_project.cu @@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ - -#define EIGEN_USE_GPU - #include "paddle/fluid/operators/math/context_project.h" namespace paddle { diff --git a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc index fead13ebadcd131afafc308740cdd39b1c53bc08..cb49e66488bd69d92430cbf6de1d08348ffe0202 100644 --- a/paddle/fluid/operators/math/jit_kernel_layer_norm.cc +++ b/paddle/fluid/operators/math/jit_kernel_layer_norm.cc @@ -79,16 +79,16 @@ class LayerNormKernelImpl : public LayerNormKernel { } }; -#define INTRIAVX_FLOAT(isa, block) \ +#define INTRIAVX_FLOAT(isa, jit_block) \ template <> \ - LayerNormKernelImpl::LayerNormKernelImpl(int right) \ + LayerNormKernelImpl::LayerNormKernelImpl(int right) \ : LayerNormKernel() { \ this->num_ = right; \ this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \ this->end_ = this->num_ - this->rest_; \ } \ template <> \ - void LayerNormKernelImpl::Compute( \ + void LayerNormKernelImpl::Compute( \ float* x, float* out, float* mean, float* var, const float* scale, \ const float* bias, int height, const float epsilon) const { \ __m256 sum; \ @@ -97,6 +97,7 @@ class LayerNormKernelImpl : public LayerNormKernel { __m256 tmp; \ size_t offset; \ size_t j; \ + size_t block = YMM_FLOAT_BLOCK; \ __m256 reverse_num_vec = \ _mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \ __m256 epsilon_vec = _mm256_set1_ps(epsilon); \ @@ -221,12 +222,14 @@ INTRIAVX_FLOAT(platform::avx, kEQ8); INTRIAVX_FLOAT(platform::avx, kGT8LT16); INTRIAVX_FLOAT(platform::avx, kEQ16); INTRIAVX_FLOAT(platform::avx, kGT16); -#endif -#ifdef __AVX2__ INTRIAVX_FLOAT(platform::avx2, kEQ8); INTRIAVX_FLOAT(platform::avx2, kGT8LT16); INTRIAVX_FLOAT(platform::avx2, kEQ16); INTRIAVX_FLOAT(platform::avx2, kGT16); +INTRIAVX_FLOAT(platform::avx512f, kEQ8); +INTRIAVX_FLOAT(platform::avx512f, kGT8LT16); +INTRIAVX_FLOAT(platform::avx512f, kEQ16); +INTRIAVX_FLOAT(platform::avx512f, kGT16); #endif #undef INTRIAVX_FLOAT diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 79b7538ad05b0ff348b8264d50b63211b5254e80..9372d63f0bea2b0c9f37d47376d7b7014e381a33 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ - -#define EIGEN_USE_GPU #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/math/sequence2batch.cu b/paddle/fluid/operators/math/sequence2batch.cu index be73adfc0cbe37ed8831b5ad34e66bc95e342e9d..9ab13659c1cc5b59d28395bcebcfb43fac5b4544 100644 --- a/paddle/fluid/operators/math/sequence2batch.cu +++ b/paddle/fluid/operators/math/sequence2batch.cu @@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
 */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/math/softmax.cu b/paddle/fluid/operators/math/softmax.cu
index 2e9669049e36478549b793e3fa76220825888e21..71d137398267f61d8cc01907d6a9498eef8d62dc 100644
--- a/paddle/fluid/operators/math/softmax.cu
+++ b/paddle/fluid/operators/math/softmax.cu
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
-
 #include
 
 #include "paddle/fluid/operators/math/math_function.h"
diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu
index 413b8ace67bd0a36849373812950834523b62216..921c2e1298906655767c1e7f30dc34b2c564c671 100644
--- a/paddle/fluid/operators/mean_op.cu
+++ b/paddle/fluid/operators/mean_op.cu
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
-
 #include "paddle/fluid/operators/mean_op.h"
 #include "paddle/fluid/platform/float16.h"
 
diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cu b/paddle/fluid/operators/optimizers/adadelta_op.cu
index 3fbfee5df05770a1206ab3170d3baffdd20bc77b..562a157f063b44d65254d556d44439eee3636c4c 100644
--- a/paddle/fluid/operators/optimizers/adadelta_op.cu
+++ b/paddle/fluid/operators/optimizers/adadelta_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/adadelta_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/optimizers/adagrad_op.cu b/paddle/fluid/operators/optimizers/adagrad_op.cu
index 4efe56855a4bdca41d24f02c29a618a8d4232887..5043468d4c5f721ae0906b1a319eb3ec10b26580 100644
--- a/paddle/fluid/operators/optimizers/adagrad_op.cu
+++ b/paddle/fluid/operators/optimizers/adagrad_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/operators/optimizers/adagrad_op.h"
diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu
index e8090ebacfe85153aba9e275c9cd1c55fd7af15e..4eb2db717d45a730798eef48d3d10bce9d387c4b 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cu
+++ b/paddle/fluid/operators/optimizers/adam_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/adam_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/optimizers/adamax_op.cu b/paddle/fluid/operators/optimizers/adamax_op.cu
index e54adcb142fe0d50dad23fe5df14bd6f28220d8a..80e0219d4414db2909b5babc22599d8c0d906c7d 100644
--- a/paddle/fluid/operators/optimizers/adamax_op.cu
+++ b/paddle/fluid/operators/optimizers/adamax_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/adamax_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu
index 84d65e39329659f82099011f9ec60468d5db6328..dc568802a2b19fee5c8d7fd8d07c929cba8ab4e3 100644
--- a/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu
+++ b/paddle/fluid/operators/optimizers/decayed_adagrad_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/decayed_adagrad_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/optimizers/ftrl_op.cu b/paddle/fluid/operators/optimizers/ftrl_op.cu
index f836b75df93861a0fd670f2a0e786e6a797a4661..acf8e38ca0f5a3cf9899f4898898013e8a2afdd2 100644
--- a/paddle/fluid/operators/optimizers/ftrl_op.cu
+++ b/paddle/fluid/operators/optimizers/ftrl_op.cu
@@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/ftrl_op.h"
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu
index d1c1f747b70c3ceb806da06e6786a70b62a32995..591dead3b12763e4cd1b9c390a87816ab121fbf8 100644
--- a/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu
+++ b/paddle/fluid/operators/optimizers/proximal_adagrad_op.cu
@@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/proximal_adagrad_op.h"
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/optimizers/proximal_gd_op.cu b/paddle/fluid/operators/optimizers/proximal_gd_op.cu
index 7aa0e1015008eba0c1cf63ba1278dc2b8049b20b..d556fa74f19529d0e2f80d4c6dbfca62498c9dcc 100644
--- a/paddle/fluid/operators/optimizers/proximal_gd_op.cu
+++ b/paddle/fluid/operators/optimizers/proximal_gd_op.cu
@@ -10,8 +10,6 @@ Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/proximal_gd_op.h"
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/optimizers/rmsprop_op.cu b/paddle/fluid/operators/optimizers/rmsprop_op.cu
index 69e35a309e04f61068d9ff1b6d9f1450d2524253..8b17d6a0204045a9b20adb79dbad72dff5ba267e 100644
--- a/paddle/fluid/operators/optimizers/rmsprop_op.cu
+++ b/paddle/fluid/operators/optimizers/rmsprop_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/optimizers/rmsprop_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/pad_constant_like_op.cu b/paddle/fluid/operators/pad_constant_like_op.cu
index ea69577904577de353b63491973bf74b7724e18e..9e62a6dc9d34a96c59a08d0e5fd6cdd9f0d6d51d 100644
--- a/paddle/fluid/operators/pad_constant_like_op.cu
+++ b/paddle/fluid/operators/pad_constant_like_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/pad_constant_like_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/pad_op.cu b/paddle/fluid/operators/pad_op.cu
index 9cddef9cf1d3c43701a4f0ed3f70dcb30c1dbd02..95098a8dca36594c3af60ad8488217e71c673a75 100644
--- a/paddle/fluid/operators/pad_op.cu
+++ b/paddle/fluid/operators/pad_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/pad_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu
index 63cd47a38a0ff6413c430c6be6284c5f4bfc2595..4897474a485d8417854ffb53aa8ee64321c78ae7 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu
+++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cu
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
-
 #include "paddle/fluid/operators/sequence_ops/sequence_pool_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
index 9aadac1a416034a3510dea2916d7577efbc2f8c2..a1fbc7e5fab71df486b53c31464c99e9c4557ccd 100644
--- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
+++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/smooth_l1_loss_op.cu b/paddle/fluid/operators/smooth_l1_loss_op.cu
index dfbb5c905884b57413587a4f6c33b0238b740c73..e5df479090fabe926f65f58e2300e3ee2027e54d 100644
--- a/paddle/fluid/operators/smooth_l1_loss_op.cu
+++ b/paddle/fluid/operators/smooth_l1_loss_op.cu
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
-
 #include "paddle/fluid/operators/smooth_l1_loss_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
index 6d48796191dd13a45f0c7267bfaf05489f528a9d..cee3e87037e0f1439a08b7b275eedefe357a4b13 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
-
 #include
 #include "paddle/fluid/operators/math/cross_entropy.h"
 #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
diff --git a/paddle/fluid/operators/split_selected_rows_op.h b/paddle/fluid/operators/split_selected_rows_op.h
index af64607fafc6544047714e731846a2440be219b8..1fef2b3d378c96d087118d0136885e7e29aa237c 100644
--- a/paddle/fluid/operators/split_selected_rows_op.h
+++ b/paddle/fluid/operators/split_selected_rows_op.h
@@ -72,10 +72,11 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel {
     for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
       auto rows_idx = outs_rows_idx[i];
       outs[i]->set_height(height_sections[i]);
+      auto dims = x->GetCompleteDims();
+      dims[0] = rows_idx.size();
+      outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
+      outs[i]->mutable_rows()->clear();
       if (rows_idx.size() > 0) {
-        auto dims = x->GetCompleteDims();
-        dims[0] = rows_idx.size();
-        outs[i]->mutable_value()->mutable_data<T>(dims, x->place());
         for (auto idx : rows_idx) {
           outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
         }
@@ -98,6 +99,8 @@ class SplitSelectedRowsOpKernel : public framework::OpKernel {
           }
         }
       }
+      PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
+                        "rows should have the same size as tensor dim 0");
     }
   }
 };
diff --git a/paddle/fluid/operators/squared_l2_distance_op.cu b/paddle/fluid/operators/squared_l2_distance_op.cu
index 3e80ae8dd22077c0f9bbdedc24e84f6c339c5a26..c9264da838246efded7d9f85664faf0dc1cec282 100644
--- a/paddle/fluid/operators/squared_l2_distance_op.cu
+++ b/paddle/fluid/operators/squared_l2_distance_op.cu
@@ -11,9 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
-
 #include "paddle/fluid/operators/squared_l2_distance_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/squared_l2_norm_op.cu b/paddle/fluid/operators/squared_l2_norm_op.cu
index 87830413da3f141f01a97966ae0e2b0501ed600a..e31cfeb78ab8a8d1b55a198fe7a2c647a3dce665 100644
--- a/paddle/fluid/operators/squared_l2_norm_op.cu
+++ b/paddle/fluid/operators/squared_l2_norm_op.cu
@@ -11,8 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
-#define EIGEN_USE_GPU
 #include "paddle/fluid/operators/squared_l2_norm_op.h"
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu
index db4c2d6c115f04b436db00854ca4b02fea09866b..6125ed07b6d0f92fa317c581a06117dcfa7359ae 100644
--- a/paddle/fluid/operators/sum_op.cu
+++ b/paddle/fluid/operators/sum_op.cu
@@ -8,8 +8,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/ - -#define EIGEN_USE_GPU #include "paddle/fluid/operators/sum_op.h" #include "paddle/fluid/platform/float16.h" diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu index ee45afab93d079374aefe366425502890854c28d..466bf90c63c1496883995819cdcb19f846e4a302 100644 --- a/paddle/fluid/platform/cuda_helper_test.cu +++ b/paddle/fluid/platform/cuda_helper_test.cu @@ -93,7 +93,7 @@ TEST(CudaAtomic, float16) { // unalignment of uint8 void TestUnalign(size_t num, const int shift_bit) { - PADDLE_ENFORCE(num % 2 == 0, "must be a multiple of 2"); + ASSERT_EQ(num % 2, 0); float16 *in1, *in2, *out; float16 *d_in1, *d_in2; size_t size = sizeof(uint8_t) * (num + shift_bit); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 3edd727978010e20203ab994562ce922b6ee0bad..ce1494f1702aa2bbcb255dec122cdae776cbc4a0 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/gpu_info.h" -#define EIGEN_USE_GPU #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index a85972bdb72ca3119cc14f9e2b810c3875443538..01ee67fd07f848356e801be95d53a61bb5b08e37 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -62,45 +62,54 @@ inline std::string demangle(std::string name) { return name; } #endif struct EnforceNotMet : public std::exception { - std::exception_ptr exp_; std::string err_str_; - EnforceNotMet(std::exception_ptr e, const char* f, int l) : exp_(e) { - static constexpr int TRACE_STACK_LIMIT = 100; + EnforceNotMet(std::exception_ptr e, const char* f, int l) { try { - std::rethrow_exception(exp_); - } catch (const std::exception& exp) { - std::ostringstream sout; + std::rethrow_exception(e); + } catch (std::exception& e) { + Init(e.what(), f, l); + } + } - sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; - sout << "PaddlePaddle Call Stacks: " << std::endl; + template + EnforceNotMet(const char* f, int l, ARGS... 
args) { + Init(string::Sprintf(args...), f, l); + } + + const char* what() const noexcept override { return err_str_.c_str(); } + + private: + template + inline void Init(StrType what, const char* f, int l) { + static constexpr int TRACE_STACK_LIMIT = 100; + std::ostringstream sout; + + sout << string::Sprintf("%s at [%s:%d]", what, f, l) << std::endl; + sout << "PaddlePaddle Call Stacks: " << std::endl; #if !defined(_WIN32) - void* call_stack[TRACE_STACK_LIMIT]; - auto size = backtrace(call_stack, TRACE_STACK_LIMIT); - auto symbols = backtrace_symbols(call_stack, size); - - Dl_info info; - for (int i = 0; i < size; ++i) { - if (dladdr(call_stack[i], &info) && info.dli_sname) { - auto demangled = demangle(info.dli_sname); - auto addr_offset = static_cast(call_stack[i]) - - static_cast(info.dli_saddr); - sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, - 2 + sizeof(void*) * 2, call_stack[i], - demangled, addr_offset); - } else { - sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2, - call_stack[i]); - } + void* call_stack[TRACE_STACK_LIMIT]; + auto size = backtrace(call_stack, TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack, size); + Dl_info info; + for (int i = 0; i < size; ++i) { + if (dladdr(call_stack[i], &info) && info.dli_sname) { + auto demangled = demangle(info.dli_sname); + auto addr_offset = static_cast(call_stack[i]) - + static_cast(info.dli_saddr); + sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, + 2 + sizeof(void*) * 2, call_stack[i], demangled, + addr_offset); + } else { + sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2, + call_stack[i]); } - free(symbols); + } + free(symbols); #else - sout << "Windows not support stack backtrace yet."; + sout << "Windows not support stack backtrace yet."; #endif - err_str_ = sout.str(); - } + err_str_ = sout.str(); } - - const char* what() const noexcept { return err_str_.c_str(); } }; struct EOFException : public std::exception { @@ -242,13 +251,8 @@ inline void throw_on_error(T e) { throw_on_error(e, ""); } -#define PADDLE_THROW(...) \ - do { \ - throw ::paddle::platform::EnforceNotMet( \ - std::make_exception_ptr( \ - std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \ - __FILE__, __LINE__); \ - } while (false) +#define PADDLE_THROW(...) \ + throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__) #ifndef REPLACE_ENFORCE_GLOG #define PADDLE_ENFORCE(...) \ diff --git a/paddle/fluid/platform/float16.h b/paddle/fluid/platform/float16.h index 9d48557caf75f3571ead3df43a1a93cf65e4b8cb..98afe843c0035ec14ad874508dc02b8d1d3d359c 100644 --- a/paddle/fluid/platform/float16.h +++ b/paddle/fluid/platform/float16.h @@ -71,9 +71,6 @@ struct float16; } // namespace platform } // namespace paddle -// NOTE(): -// Do not move the eigen.h header, otherwise the eigen_vector will failed. 
-#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/platform/hostdevice.h" #include "unsupported/Eigen/CXX11/Tensor" diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index dca0c01ab229ce8c1f8578ad489ef927e15a1068..74b4f2e937b3d3715b13b03e8d3618c0afafb69c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -336,6 +336,8 @@ PYBIND11_MODULE(core, m) { .def("get_tensor", [](SelectedRows &self) { return self.mutable_value(); }, py::return_value_policy::reference) + .def("numel", + [](SelectedRows &self) -> int64_t { return self.value().numel(); }) .def("set_height", &SelectedRows::set_height) .def("height", &SelectedRows::height) .def("set_rows", diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 52417a1eaf74658b4d87394404eafd5343a3c5fe..a532f94c6dd08de68c56e7af974dbf9a371cf121 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -127,7 +127,8 @@ def __bootstrap__(): 'use_ngraph', 'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads', "dist_threadpool_size", 'eager_delete_tensor_gb', 'allocator_strategy', - 'reader_queue_speed_test_mode', 'print_sub_graph_dir' + 'reader_queue_speed_test_mode', 'print_sub_graph_dir', + 'pe_profile_fname' ] if 'Darwin' not in sysstr: read_env_flags.append('use_pinned_memory') diff --git a/python/paddle/fluid/average.py b/python/paddle/fluid/average.py index 42cd3b36420ef5a17a9a7d981978ba8869809936..40a734af311e2037c1816dce97db123ebedd2f4f 100644 --- a/python/paddle/fluid/average.py +++ b/python/paddle/fluid/average.py @@ -48,6 +48,7 @@ class WeightedAverage(object): Examples: .. code-block:: python + avg = fluid.average.WeightedAverage() avg.add(value=2.0, weight=1) avg.add(value=4.0, weight=2) diff --git a/python/paddle/fluid/tests/unittests/dist_mnist.py b/python/paddle/fluid/tests/unittests/dist_mnist.py index 1cda2711f765622b0bda6f4c688f69352bbd2a6f..1c45a10a9ddde743dce9b343e4d18f568bb05e72 100644 --- a/python/paddle/fluid/tests/unittests/dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/dist_mnist.py @@ -93,7 +93,7 @@ class TestDistMnist2x2(TestDistRunnerBase): # TODO(typhoonzero): fix distributed adam optimizer # opt = fluid.optimizer.AdamOptimizer( # learning_rate=0.001, beta1=0.9, beta2=0.999) - opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) + opt = fluid.optimizer.Momentum(learning_rate=self.lr, momentum=0.9) # Reader train_reader = paddle.batch( diff --git a/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2130f9049c7ee294444282e59c654551f76603 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_concat_mkldnn_op.py @@ -0,0 +1,61 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +from test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3 + + +class TestMKLDNNConcatOp(TestConcatOp): + def setUp(self): + super(TestMKLDNNConcatOp, self).setUp() + self.attrs["use_mkldnn"] = True + self._cpu_only = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNConcatOp2(TestConcatOp2): + def setUp(self): + super(TestMKLDNNConcatOp2, self).setUp() + self.attrs["use_mkldnn"] = True + self._cpu_only = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNConcatOp3(TestConcatOp3): + def setUp(self): + super(TestMKLDNNConcatOp3, self).setUp() + self.attrs["use_mkldnn"] = True + self._cpu_only = True + + def test_check_grad(self): + pass + + def init_kernel_type(self): + self.use_mkldnn = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py index 26fa20291b52e469066e23b5c29a8e11b40a1270..cedb3383ed4728306c61d7f987850000506457c7 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_base.py @@ -32,7 +32,7 @@ DEFAULT_BATCH_SIZE = 2 class TestDistRunnerBase(object): - def get_model(self, batch_size=DEFAULT_BATCH_SIZE): + def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -56,6 +56,7 @@ class TestDistRunnerBase(object): return t def run_pserver(self, args): + self.lr = args.lr self.get_model(batch_size=args.batch_size) # NOTE: pserver should not call memory optimize t = self.get_transpiler(args.trainer_id, @@ -71,6 +72,7 @@ class TestDistRunnerBase(object): exe.run(pserver_prog) def run_trainer(self, args): + self.lr = args.lr test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ self.get_model(batch_size=args.batch_size) @@ -189,6 +191,7 @@ def runtime_main(test_class): parser.add_argument( '--use_reader_alloc', action='store_true', required=False) parser.add_argument('--batch_size', required=False, type=int, default=2) + parser.add_argument('--lr', required=False, type=float, default=0.001) parser.add_argument( '--batch_merge_repeat', required=False, type=int, default=1) @@ -234,6 +237,7 @@ class TestDistBase(unittest.TestCase): self._dc_asgd = False # must use with async mode self._use_reader_alloc = True self._nccl2_mode = False + self._lr = 0.001 self._setup_config() self._after_setup_config() @@ -284,7 +288,8 @@ class TestDistBase(unittest.TestCase): batch_size=DEFAULT_BATCH_SIZE, batch_merge_repeat=1): - cmd = "%s %s --role trainer" % (self._python_interp, model) + cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model, + self._lr) if batch_size != DEFAULT_BATCH_SIZE: cmd += " --batch_size %d" % batch_size if batch_merge_repeat > 1: @@ -330,13 +335,13 @@ class TestDistBase(unittest.TestCase): ps0_ep, ps1_ep = self._ps_endpoints.split(",") - tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver" + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f" tr0_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 0, ps0_ep, self._trainers) + 0, ps0_ep, self._trainers, self._lr) tr1_cmd = tr_cmd % \ (self._python_interp, 
model, self._ps_endpoints, - 1, ps1_ep, self._trainers) + 1, ps1_ep, self._trainers, self._lr) if self._sync_mode: tr0_cmd += " --sync_mode" @@ -425,13 +430,13 @@ class TestDistBase(unittest.TestCase): worker_endpoints = self._ps_endpoints.split(",") w0_ep, w1_ep = worker_endpoints - tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2" + tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f" tr0_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 0, w0_ep) + 0, w0_ep, self._lr / 2) tr1_cmd = tr_cmd % \ (self._python_interp, model, self._ps_endpoints, - 1, w1_ep) + 1, w1_ep, self._lr / 2) if self._mem_opt: tr0_cmd += " --mem_opt" diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist.py b/python/paddle/fluid/tests/unittests/test_dist_mnist.py index 630bed198f4fc382d716373ea872e24b1b45bbf3..49a2ca40e3cb1dd35027345e9c38eb8b6912d2cd 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_dist_mnist.py @@ -36,7 +36,7 @@ class TestDistMnistNCCL2(TestDistBase): def test_dist_train(self): import paddle.fluid as fluid if fluid.core.is_compiled_with_cuda(): - self.check_with_place("dist_mnist.py", delta=1) + self.check_with_place("dist_mnist.py", delta=1e-5) class TestDistMnist2x2Lars(TestDistBase): diff --git a/python/paddle/fluid/tests/unittests/test_regularizer.py b/python/paddle/fluid/tests/unittests/test_regularizer.py index 20f91cf4485f2e79c20fe90143c8b7deebb9fc49..62994eec7e7f56267a0990d9a5e3b5c62d7d5fe4 100644 --- a/python/paddle/fluid/tests/unittests/test_regularizer.py +++ b/python/paddle/fluid/tests/unittests/test_regularizer.py @@ -15,7 +15,12 @@ from __future__ import print_function import unittest - +from functools import partial +import contextlib +import numpy as np +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid import paddle.fluid.framework as framework import paddle.fluid.optimizer as optimizer import paddle.fluid.regularizer as regularizer @@ -97,5 +102,134 @@ class TestL1DecayRegularizer(unittest.TestCase): self.assertEqual(block.ops[-3].type, 'sign') +def bow_net(data, + label, + dict_dim, + is_sparse=False, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + BOW net + This model is from https://github.com/PaddlePaddle/models: + fluid/PaddleNLP/text_classification/nets.py + """ + emb = fluid.layers.embedding( + input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim]) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bow_tanh = fluid.layers.tanh(bow) + fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + + return avg_cost + + +class TestRegularizer(unittest.TestCase): + def setUp(self): + self.word_dict = paddle.dataset.imdb.word_dict() + reader = paddle.batch( + paddle.dataset.imdb.train(self.word_dict), batch_size=8)() + self.train_data = [next(reader) for _ in range(5)] + + def get_places(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + return places + + @contextlib.contextmanager + def scope_prog_guard(self, main_prog, startup_prog): + scope = fluid.core.Scope() + with fluid.unique_name.guard(): + with 
fluid.scope_guard(scope): + with fluid.program_guard(main_prog, startup_prog): + yield + + def run_program(self, place, feed_list): + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=feed_list, place=place) + exe.run(fluid.default_startup_program()) + + main_prog = fluid.default_main_program() + param_list = [var.name for var in main_prog.block(0).all_parameters()] + + param_sum = [] + for data in self.train_data: + out = exe.run(main_prog, + feed=feeder.feed(data), + fetch_list=param_list) + p_sum = 0 + for v in out: + p_sum += np.sum(np.abs(v)) + param_sum.append(p_sum) + return param_sum + + def check_l2decay_regularizer(self, place, model): + main_prog = fluid.framework.Program() + startup_prog = fluid.framework.Program() + startup_prog.random_seed = 1 + with self.scope_prog_guard( + main_prog=main_prog, startup_prog=startup_prog): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + avg_cost = model(data, label, len(self.word_dict)) + + optimizer = fluid.optimizer.Adagrad( + learning_rate=0.1, + regularization=fluid.regularizer.L2Decay(1.0)) + optimizer.minimize(avg_cost) + param_sum = self.run_program(place, [data, label]) + return param_sum + + def check_l2decay(self, place, model): + main_prog = fluid.framework.Program() + startup_prog = fluid.framework.Program() + startup_prog.random_seed = 1 + with self.scope_prog_guard( + main_prog=main_prog, startup_prog=startup_prog): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + avg_cost_l2 = model(data, label, len(self.word_dict)) + + param_list = fluid.default_main_program().block(0).all_parameters() + para_sum = [] + for para in param_list: + para_mul = fluid.layers.square(x=para) + para_sum.append(fluid.layers.reduce_sum(input=para_mul)) + avg_cost_l2 += fluid.layers.sums(para_sum) * .5 + + optimizer = fluid.optimizer.Adagrad(learning_rate=0.1) + optimizer.minimize(avg_cost_l2) + param_sum = self.run_program(place, [data, label]) + return param_sum + + def test_l2(self): + for place in self.get_places(): + dense_sparse_p_sum = [] + for sparse in [True, False]: + model = partial(bow_net, is_sparse=sparse) + framework_l2 = self.check_l2decay_regularizer(place, model) + l2 = self.check_l2decay(place, model) + assert len(l2) == len(framework_l2) + for i in range(len(l2)): + assert np.isclose(a=framework_l2[i], b=l2[i], rtol=5e-5) + dense_sparse_p_sum.append(framework_l2) + + assert len(dense_sparse_p_sum[0]) == len(dense_sparse_p_sum[1]) + for i in range(len(dense_sparse_p_sum[0])): + assert np.isclose( + a=dense_sparse_p_sum[0][i], + b=dense_sparse_p_sum[1][i], + rtol=5e-5) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py index 50204b8a77c187aa695da83860960566448d290f..f8847e1570dc47d432777faa15f4004f1a7111a6 100644 --- a/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_selected_rows_op.py @@ -63,6 +63,7 @@ class TestSpliteSelectedRows(unittest.TestCase): # expected output selected rows expected_out0_rows = [0, 4] expected_out1_rows = [0, 2] + expected_out2_rows = [] expected_out4_rows = [0] op = Operator( @@ -75,6 +76,7 @@ class TestSpliteSelectedRows(unittest.TestCase): self.assertEqual(outs[0].rows(), 
expected_out0_rows)
         self.assertEqual(outs[1].rows(), expected_out1_rows)
+        self.assertEqual(outs[2].rows(), expected_out2_rows)
         self.assertEqual(outs[4].rows(), expected_out4_rows)
 
         self.assertEqual(outs[0].height(), height_sections[0])
@@ -84,6 +86,9 @@
         self.assertAlmostEqual(4.0, np.array(outs[1].get_tensor())[1, 1])
         self.assertAlmostEqual(8.0, np.array(outs[4].get_tensor())[0, 1])
 
+        self.assertEqual(outs[2].numel(), 0)
+        self.assertEqual(outs[3].numel(), 0)
+
     def check_grad_with_place(self, place):
         scope = core.Scope()
         height = 10