diff --git a/cmake/FindJeMalloc.cmake b/cmake/FindJeMalloc.cmake index 7911f77c4c35b5cf0fa47ff98282986eef974832..b95287160ba610b2dfa93ba15e7c7c8214d80ac1 100644 --- a/cmake/FindJeMalloc.cmake +++ b/cmake/FindJeMalloc.cmake @@ -19,3 +19,10 @@ find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALL mark_as_advanced( JEMALLOC_LIBRARIES JEMALLOC_INCLUDE_DIR) + +if (JEMALLOC_FOUND) + add_library(jemalloc::jemalloc UNKNOWN IMPORTED) + set_target_properties(jemalloc::jemalloc PROPERTIES + IMPORTED_LOCATION ${JEMALLOC_LIBRARIES} + INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}") +endif() diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 10ecdf0ea873718a23ece8fa97faa3728652c188..16432ce2b803f6d21bbf47200eda5404269b750f 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -2,9 +2,11 @@ if(NOT WITH_GPU) return() endif() -set(paddle_known_gpu_archs "30 35 50 52 60 61 70 75") +set(paddle_known_gpu_archs "30 35 50 52 60 61 70") set(paddle_known_gpu_archs7 "30 35 50 52") set(paddle_known_gpu_archs8 "30 35 50 52 60 61") +set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70") +set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75") ###################################################################################### # A function for automatic detection of GPUs installed (if autodetection is enabled) @@ -155,6 +157,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x # warning for now. list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets") add_definitions("-DPADDLE_CUDA_BINVER=\"80\"") +elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs9}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + add_definitions("-DPADDLE_CUDA_BINVER=\"90\"") +elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x + set(paddle_known_gpu_archs ${paddle_known_gpu_archs10}) + list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED") + list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__") + add_definitions("-DPADDLE_CUDA_BINVER=\"100\"") endif() include_directories(${CUDA_INCLUDE_DIRS}) diff --git a/cmake/external/boost.cmake b/cmake/external/boost.cmake index 5a78a1d1b7dea0d95ae3fa2c9f39679899dd1bcb..12412a51a0fd1aaa9702bd4547fb935d94012ada 100644 --- a/cmake/external/boost.cmake +++ b/cmake/external/boost.cmake @@ -23,11 +23,8 @@ set(BOOST_PROJECT "extern_boost") # checked that the devtools package of CentOS 6 installs boost 1.41.0. # So we use 1.41.0 here. 
set(BOOST_VER "1.41.0") -if((NOT DEFINED BOOST_TAR) OR (NOT DEFINED BOOST_URL)) - message(STATUS "use pre defined download url") - set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE) - set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE) -endif() +set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE) +set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE) MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake index a9b99e9ab87c724ac7062e3a20b247bf6ea44634..03f0dee85911bdaa0312b624114b7f4aef1fb723 100644 --- a/cmake/external/mkldnn.cmake +++ b/cmake/external/mkldnn.cmake @@ -55,7 +55,7 @@ ExternalProject_Add( ${MKLDNN_PROJECT} ${EXTERNAL_PROJECT_LOG_ARGS} DEPENDS ${MKLDNN_DEPENDS} - GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" + GIT_REPOSITORY "https://github.com/intel/mkl-dnn.git" GIT_TAG "830a10059a018cd2634d94195140cf2d8790a75a" PREFIX ${MKLDNN_SOURCES_DIR} UPDATE_COMMAND "" diff --git a/cmake/external/mklml.cmake b/cmake/external/mklml.cmake index 96127e78d64a9df7dd32730d27c939b88fc0c739..43322a257a02c3fd756078db6fe20b582826066a 100644 --- a/cmake/external/mklml.cmake +++ b/cmake/external/mklml.cmake @@ -16,6 +16,12 @@ IF(NOT ${WITH_MKLML}) return() ENDIF(NOT ${WITH_MKLML}) +IF(APPLE) + MESSAGE(WARNING "Mac is not supported with MKLML in Paddle yet. Force WITH_MKLML=OFF.") + SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in MacOS" FORCE) + return() +ENDIF() + INCLUDE(ExternalProject) SET(MKLML_DST_DIR "mklml") SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") @@ -23,32 +29,24 @@ SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) SET(MKLML_ROOT ${MKLML_INSTALL_DIR}) SET(MKLML_INC_DIR ${MKLML_ROOT}/include) SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib) -if(WIN32) +SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") + +SET(TIME_VERSION "2019.0.1.20181227") +IF(WIN32) + SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE) SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll) -else() +ELSE() + SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE) + SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) -endif() -SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib") - -IF((NOT DEFINED MKLML_VER) OR (NOT DEFINED MKLML_URL)) - MESSAGE(STATUS "use pre defined download url") - if(WIN32) - SET(MKLML_VER "mklml_win_2019.0.1.20180928" CACHE STRING "" FORCE) - SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE) - elseif(APPLE) - SET(MKLML_VER "mklml_mac_2019.0.1.20180928" CACHE STRING "" FORCE) - SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) - else() - SET(MKLML_VER "mklml_lnx_2019.0.1.20180928" CACHE STRING "" FORCE) - SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE) - ENDIF() -endif() +ENDIF() SET(MKLML_PROJECT "extern_mklml") 
MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}") diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 4e31392b9898f7af3457b1a70a0ab5b8053f70c9..05293b8b06b55bb0b83a30c7eb059efe0b61e57e 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -117,7 +117,7 @@ function(common_link TARGET_NAME) endif() if (WITH_JEMALLOC) - target_link_libraries(${TARGET_NAME} ${JEMALLOC_LIBRARIES}) + target_link_libraries(${TARGET_NAME} jemalloc::jemalloc) endif() endfunction() diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 179aa145284ed62c2c96669499b277df45ea8066..c1ba6606f1064750a9d7e087ded1ec3634bcc4a5 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -94,4 +94,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass fuse_elewise_add_act_pass multi_batch_merge_pass - memory_optimize_pass) + memory_optimize_pass lock_free_optimize_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index a68b69e0264e2f202dd41b56faf2f589118a3a53..df0ff772c9d35c88ec5a6112525c56aa92d359b9 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -232,3 +232,4 @@ USE_PASS(analysis_var_pass); USE_PASS(sequential_execution_pass); USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); +USE_PASS(lock_free_optimize_pass); diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6d795e1e2d5407ecacf5fb4af539919d72bff404..6e6db3d3efbc9fbb17e7ee45402dd4cb7f4f7a34 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -31,6 +31,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) +pass_library(lock_free_optimize_pass base) pass_library(fc_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference) pass_library(infer_clean_graph_pass inference) diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.cc b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..92e897ca9ce02ed67f026fd08062842e3bafa098 --- /dev/null +++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.cc @@ -0,0 +1,358 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/lock_free_optimize_pass.h" + +#include +#include +#include + +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +const char kSumGradOpName[] = "sum"; +// TODO(minqiyang): only support sgd at current time, please add +// other optimizers later. +const char kOptimizerType[] = "sgd"; + +std::unique_ptr LockFreeOptimizePass::ApplyImpl( + std::unique_ptr graph) const { + PADDLE_ENFORCE(graph.get()); + + // We could collect all weights' name from SGD, where + // W1 <- SGD(W0, Grad0) + std::unordered_set weight_var_set; + for (auto* node : graph->Nodes()) { + if (IsOpNamed(node, kOptimizerType)) { + auto& param_out_vars = node->Op()->Output("ParamOut"); + PADDLE_ENFORCE(param_out_vars.size() == 1u); + weight_var_set.insert(param_out_vars[0]); + } + } + + // find all grad's merge op via weight name, where + // Grad0 <- SUM(Grad1, Grad2, Grad3 ...) + std::unordered_set grad_sum_op_set; + for (ir::Node* node : graph->Nodes()) { + if (IsOpNamed(node, kSumGradOpName)) { + for (ir::Node* output : node->outputs) { + // strip the last grad suffix @GRAD + std::string var_name = output->Name(); + const std::string suffix(kGradVarSuffix); + if (var_name != suffix && var_name.size() > suffix.size() && + var_name.substr(var_name.size() - suffix.size()) == suffix) { + // if so then strip them off + var_name = var_name.substr(0, var_name.size() - suffix.size()); + if (weight_var_set.find(var_name) != weight_var_set.end()) { + grad_sum_op_set.insert(node); + break; + } + } + } + } + } + + // get the forward op and backward op pairs, where + // out <- forward(X, W) + // Grad1 <- backward(out, X') + // Grad0 <- SUM(Grad1, Grad2, Grad3 ...) 
+ // W0 <- SGD(W1, Grad0) + for (ir::Node* node : grad_sum_op_set) { + for (ir::Node* merged_grad_var : node->outputs) { + // find the optimizers connected with sum op + if (IsVarNameEndsWith(merged_grad_var, kGradVarSuffix) && + merged_grad_var->outputs.size() == 1u) { + ir::Node* opt_node = merged_grad_var->outputs[0]; + VLOG(3) << "Found opt node " << opt_node->Name(); + + // find the backward op connected with sum op + for (ir::Node* unmerged_grad_var : node->inputs) { + if (IsVarNameContains(unmerged_grad_var, kGradVarSuffix) && + unmerged_grad_var->inputs.size() == 1u) { + ir::Node* backward_op = unmerged_grad_var->inputs[0]; + + VLOG(3) << "Found backward_op " << backward_op->Name(); + + // find the forward op related to the backward op + ir::Node* forward_op = + FindForwardOpViaBackwardOp(graph.get(), backward_op); + + VLOG(3) << "Found forward_op " << forward_op->Name(); + + PADDLE_ENFORCE(forward_op); + + Node* new_optimizer_node = CreateNewSGDNode( + graph.get(), forward_op, backward_op, node, opt_node); + + PADDLE_ENFORCE(new_optimizer_node); + } + } + } + } + } + + // Remove the sum_op and its' outputs and connected Optimizers + for (Node* sum_op : grad_sum_op_set) { + for (Node* sum_op_output : sum_op->outputs) { + for (Node* optimize_op : sum_op_output->outputs) { + if (optimize_op->NodeType() == Node::Type::kOperation && + optimize_op->Name() == kOptimizerType) { + VLOG(3) << "remove optimize_op: " << optimize_op->Name() << "_" + << optimize_op->id(); + graph->RemoveNode(optimize_op); + } + } + VLOG(3) << "remove sum_op_output: " << sum_op_output->Name() << "_" + << sum_op_output->id(); + graph->RemoveNode(sum_op_output); + } + VLOG(3) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id(); + graph->RemoveNode(sum_op); + } + + for (auto* node : graph->Nodes()) { + for (Node* output_node : node->outputs) { + if (output_node->Name() == "sgd") { + VLOG(3) << "Node link to SGD: " << node->Name() << "_" << node->id() + << " --> " << output_node->Name() << "_" << output_node->id(); + for (Node* input_node : node->inputs) { + VLOG(3) << "SGD Input link: " << input_node->Name() << "_" + << input_node->id() << " --> " << node->Name() << "_" + << node->id(); + } + } + } + } + + return graph; +} + +ir::Node* LockFreeOptimizePass::CreateNewSGDNode( + ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node, + ir::Node* grad_sum_node, ir::Node* optimize_node) const { + PADDLE_ENFORCE(graph); + PADDLE_ENFORCE(forward_node); + PADDLE_ENFORCE(backward_node); + PADDLE_ENFORCE(grad_sum_node); + PADDLE_ENFORCE(optimize_node); + + // find the grad var node between the grad sum node and backward_node + std::vector grad_vars = + FindConnectedNode(backward_node, grad_sum_node); + ir::Node* grad_node = nullptr; + for (ir::Node* node : grad_vars) { + if (!ir::IsControlDepVar(*node)) { + grad_node = node; + } + } + PADDLE_ENFORCE(grad_node); + + // create a new SGD node + OpDesc* old_desc = optimize_node->Op(); + // keep with the same block between new optimizer and the old one + OpDesc new_desc(*old_desc, old_desc->Block()); + new_desc.SetInput("Param", old_desc->Input("Param")); + new_desc.SetInput("LearningRate", old_desc->Input("LearningRate")); + new_desc.SetInput("Grad", std::vector({grad_node->Name()})); + new_desc.SetOutput("ParamOut", old_desc->Output("ParamOut")); + + std::vector op_role_vars = boost::get>( + new_desc.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName())); + // replace the second op role var, because the grad name was + // changed in new optimizer 
+ op_role_vars.pop_back(); + op_role_vars.push_back(grad_node->Name()); + new_desc.SetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), + op_role_vars); + new_desc.SetType(kOptimizerType); + + // set backward op's op role var, this will be used to + // set device_id in multi_device_pass + backward_node->Op()->SetAttr( + framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), op_role_vars); + // backward_node->Op()->SetAttr( + // framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), {}); + + // keep with the same output nodes between new optimizer and the + // old one + Node* sgd_node = graph->CreateOpNode(&new_desc); + + // change all outputs of the optimize_node to the new one + ReplaceAllDownstreamNode(optimize_node, sgd_node); + + // find connected node between forward node and optimize node + // and replace the optimize node to new sgd node + std::vector forward_opt_connected_nodes = + FindConnectedNode(forward_node, optimize_node); + for (ir::Node* node : forward_opt_connected_nodes) { + ReplaceUpstreamNode(node, optimize_node, sgd_node); + } + + // find connected node between backward node and optimize node + // and replace the optimize node to new sgd node + std::vector backward_opt_connected_nodes = + FindConnectedNode(backward_node, optimize_node); + for (ir::Node* node : backward_opt_connected_nodes) { + ReplaceUpstreamNode(node, optimize_node, sgd_node); + } + + // SGD must have only one param and LR in + PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u); + PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u); + + // LR and weight nodes should be copied + for (Node* upstream_node : optimize_node->inputs) { + if (upstream_node->Name() == old_desc->Input("LearningRate")[0] || + upstream_node->Name() == old_desc->Input("Param")[0]) { + ReplaceUpstreamNode(upstream_node, optimize_node, sgd_node); + } + } + + VLOG(3) << "Create new opt node" << sgd_node->Name() << "_" << sgd_node->id(); + + return sgd_node; +} + +std::vector LockFreeOptimizePass::FindConnectedNode( + ir::Node* upstream_node, ir::Node* downstream_node) const { + std::vector result; + for (ir::Node* out_node : upstream_node->outputs) { + for (ir::Node* in_node : downstream_node->inputs) { + if (in_node == out_node) { + result.push_back(in_node); + } + } + } + + return result; +} + +void LockFreeOptimizePass::ReplaceUpstreamNode( + ir::Node* upstream_node, ir::Node* old_optimizer_node, + ir::Node* new_optimizer_node) const { + PADDLE_ENFORCE(upstream_node); + PADDLE_ENFORCE(old_optimizer_node); + PADDLE_ENFORCE(new_optimizer_node); + + // Remove the old_optimizer_node from upstream_node's outputs vector + auto& output_node_vec = upstream_node->outputs; + for (auto output_node_iter = output_node_vec.begin(); + output_node_iter != output_node_vec.end();) { + if (*output_node_iter == old_optimizer_node) { + output_node_vec.erase(output_node_iter); + break; + } else { + ++output_node_iter; + } + } + + // Add the new_optimizer_node to upstream_node's outputs vector + output_node_vec.emplace_back(new_optimizer_node); + new_optimizer_node->inputs.emplace_back(upstream_node); +} + +void LockFreeOptimizePass::ReplaceAllDownstreamNode( + ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const { + PADDLE_ENFORCE(old_optimizer_node); + PADDLE_ENFORCE(new_optimizer_node); + + for (ir::Node* downstream_node : old_optimizer_node->outputs) { + // Remove the old_optimizer_node from downstream_node's inputs vector + auto& input_node_vec = downstream_node->inputs; + for (auto input_node_iter = 
input_node_vec.begin();
+         input_node_iter != input_node_vec.end();) {
+      if (*input_node_iter == old_optimizer_node) {
+        input_node_vec.erase(input_node_iter);
+        break;
+      } else {
+        ++input_node_iter;
+      }
+    }
+
+    // Add the new_optimizer_node to downstream_node's inputs vector
+    input_node_vec.emplace_back(new_optimizer_node);
+    new_optimizer_node->outputs.emplace_back(downstream_node);
+  }
+}
+
+ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
+    ir::Graph* graph, ir::Node* backward_node) const {
+  PADDLE_ENFORCE(graph);
+  PADDLE_ENFORCE(backward_node);
+
+  // strip the suffix _grad of backward_node's name
+  std::string forward_op_name = backward_node->Name();
+  const std::string suffix("_grad");
+  if (forward_op_name != suffix && forward_op_name.size() > suffix.size() &&
+      forward_op_name.substr(forward_op_name.size() - suffix.size()) ==
+          suffix) {
+    // if so then strip them off
+    forward_op_name =
+        forward_op_name.substr(0, forward_op_name.size() - suffix.size());
+  } else {
+    LOG(WARNING) << "Illegal backward node's name " << backward_node->Name()
+                 << " id " << backward_node->id();
+
+    return nullptr;
+  }
+
+  for (ir::Node* node : graph->Nodes()) {
+    if (node->Name() == forward_op_name) {
+      if (node->outputs.size() == 0u) {
+        // if forward_node has no output, then it has NO grad op
+        continue;
+      }
+
+      // check that every input of the backward_op that ends with @GRAD
+      // matches an output of this forward_op with @GRAD appended
+      bool is_related_forward_node = true;
+      for (ir::Node* backward_input : backward_node->inputs) {
+        if (IsVarNameEndsWith(backward_input, kGradVarSuffix)) {
+          bool meets_correct_output = false;
+          for (ir::Node* forward_output : node->outputs) {
+            if (forward_output->Name() + kGradVarSuffix ==
+                backward_input->Name()) {
+              meets_correct_output = true;
+              break;
+            }
+          }
+
+          if (!meets_correct_output) {
+            is_related_forward_node = false;
+            break;
+          }
+        }
+      }
+
+      if (is_related_forward_node) {
+        return node;
+      }
+    }
+  }
+
+  return nullptr;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(lock_free_optimize_pass,
+              paddle::framework::ir::LockFreeOptimizePass);
diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.h b/paddle/fluid/framework/ir/lock_free_optimize_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..7310f596f8a3170e84840be4bab8390b780b6577
--- /dev/null
+++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.h
@@ -0,0 +1,130 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
+#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
+
+#include <string>
+#include <vector>
+
+#include <boost/algorithm/string.hpp>
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Node;
+
+/*
+* Remove the sum op of all gradients of the backward op.
+* And remove the dependencies of the optimizer related to the
+* same backward op.
+*
+* Before this pass:
+*
+* forward_op1 forward_op2
+*     |            |
+*  grad_op1    grad_op2
+*       \        /
+*        \      /
+*         sum_op
+*           |
+*         sgd_op
+*
+* After this pass:
+* forward_op1 forward_op2
+*     |            |
+*  grad_op1    grad_op2
+*     |            |
+*  sgd_op1      sgd_op2
+*
+* sgd_op1 and sgd_op2 will update the same weight which holds the same
+* memory, so we could benefit from the acceleration
+*/
+class LockFreeOptimizePass : public Pass {
+ public:
+  virtual ~LockFreeOptimizePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+
+ private:
+  // Create a new sgd node via current optimizer node
+  ir::Node* CreateNewSGDNode(ir::Graph* graph, ir::Node* forward_node,
+                             ir::Node* backward_node, ir::Node* grad_sum_node,
+                             ir::Node* optimize_node) const;
+
+  // Replace the input weight's optimizers
+  void ReplaceUpstreamNode(ir::Node* upstream_node,
+                           ir::Node* old_optimizer_node,
+                           ir::Node* new_optimizer_node) const;
+
+  // Replace the output weight's optimizers
+  void ReplaceAllDownstreamNode(ir::Node* old_optimizer_node,
+                                ir::Node* new_optimizer_node) const;
+
+  // Find all weight variables in graph
+  bool FindAllWeightVars(ir::Graph* graph) const;
+
+  // Find the forward_op node via the backward_op node
+  ir::Node* FindForwardOpViaBackwardOp(ir::Graph* graph,
+                                       ir::Node* backward_node) const;
+
+  std::vector<ir::Node*> FindConnectedNode(ir::Node* upstream_node,
+                                           ir::Node* downstream_node) const;
+
+  inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kOperation && node->Name() == name;
+  }
+
+  inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable && node->Name() == name;
+  }
+
+  inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable &&
+           boost::algorithm::ends_with(node->Name(), name);
+  }
+
+  inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
+    PADDLE_ENFORCE(node);
+
+    return node->NodeType() == Node::Type::kVariable &&
+           node->Name().find(name) != std::string::npos;
+  }
+
+  inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
+    PADDLE_ENFORCE(ctrl_dep_node);
+    PADDLE_ENFORCE(node);
+
+    return IsControlDepVar(*ctrl_dep_node) &&
+           ctrl_dep_node->inputs.size() >= 1u &&
+           ctrl_dep_node->inputs[0] == node;
+  }
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+#endif  // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc
index a5742dbd3d66a47ca108768d875e5764a0e62f4f..953618560913229cd1e47659ad61e621efc10ed1 100644
--- a/paddle/fluid/framework/scope.cc
+++ b/paddle/fluid/framework/scope.cc
@@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) {
 }
 
 Variable* Scope::Var(std::string* name) {
-  auto new_name = string::Sprintf("%p.%d", this, vars_.size());
+  SCOPE_VARS_WRITER_LOCK
+  auto new_name = std::to_string(reinterpret_cast<uintptr_t>(this)) + "."
+ + std::to_string(vars_.size()); if (name != nullptr) { *name = new_name; } - SCOPE_VARS_WRITER_LOCK return VarInternal(new_name); } diff --git a/paddle/fluid/framework/var_type_traits.cc b/paddle/fluid/framework/var_type_traits.cc index c3c5bab23b92a0274cf786ea2f18d8246706162f..a37b1fbab8cfd0642beaf725c02941002b2176b3 100644 --- a/paddle/fluid/framework/var_type_traits.cc +++ b/paddle/fluid/framework/var_type_traits.cc @@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder { } // namespace detail -const std::type_index &ToTypeIndex(int var_id) { +const std::type_index &VarTraitIdToTypeIndex(int var_id) { return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); } -const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } +const char *ToTypeName(int var_id) { + return VarTraitIdToTypeIndex(var_id).name(); +} -int ToTypeId(const std::type_index &type) { +int TypeIndexToVarTraitId(const std::type_index &type) { return detail::VarIdToTypeIndexMapHolder::ToTypeId(type); } diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index cc68cf2ab8e1bbc8a57cf97a2084610440a75f85..733542e4972b16a71f9e76c3076b424b7a901066 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -66,8 +66,8 @@ namespace paddle { namespace framework { const char *ToTypeName(int var_id); -const std::type_index &ToTypeIndex(int var_id); -int ToTypeId(const std::type_index &type); +const std::type_index &VarTraitIdToTypeIndex(int var_id); +int TypeIndexToVarTraitId(const std::type_index &type); namespace detail { diff --git a/paddle/fluid/framework/var_type_traits_test.cc b/paddle/fluid/framework/var_type_traits_test.cc index 00840d634d802cfe17fbff127a75606cb5e2cf79..a47275e1ca25a4f66e67b4986ec78e49ea952a51 100644 --- a/paddle/fluid/framework/var_type_traits_test.cc +++ b/paddle/fluid/framework/var_type_traits_test.cc @@ -45,10 +45,11 @@ struct TypeIndexChecker { constexpr auto kId = VarTypeTrait::kId; std::type_index actual_type(typeid(Type)); EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); - EXPECT_EQ(ToTypeIndex(kId), actual_type); - EXPECT_EQ(ToTypeId(actual_type), kId); - EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type); - EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId); + EXPECT_EQ(VarTraitIdToTypeIndex(kId), actual_type); + EXPECT_EQ(TypeIndexToVarTraitId(actual_type), kId); + EXPECT_EQ(VarTraitIdToTypeIndex(TypeIndexToVarTraitId(actual_type)), + actual_type); + EXPECT_EQ(TypeIndexToVarTraitId(VarTraitIdToTypeIndex(kId)), kId); EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index f84e1ab6b827b3b96d0a503394d95b06ed25a3d2..4c84d02d8679c4d42c0d02ae83e7f869c0f5ce8b 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) { i++) { LOG(INFO) << "data: " << static_cast(outputs.front().data.data())[i] << " result: " << result[i]; - PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i], - result[i]); + EXPECT_NEAR(static_cast(outputs.front().data.data())[i], result[i], + 1e-3); } } diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index 
d3ea511d8f4d8cbec1be57633391f00e29a3e6e9..add9b70f2cd960a94232b35edb928ab4115cbff0 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps} ir_graph_build_pass ir_analysis_pass analysis_passes + subgraph_detector CACHE INTERNAL "") diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 437005825db7e0718b52ac830dd56ac87069ed39..bde2791add4075be6949703dfbea634966d25c1c 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -190,6 +190,26 @@ void BenchGRUKernel() { } } +template +void BenchSeqPoolKernel() { + std::vector pool_types = { + jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; + for (auto type : pool_types) { + for (int w : TestSizes()) { + jit::seq_pool_attr_t attr(w, type); + for (int h : TestSizes()) { + attr.h = h; + std::vector x(h * w), y(w); + RandomVec(h * w, x.data(), -2.f, 2.f); + const T* x_data = x.data(); + T* y_data = y.data(); + BenchAllImpls, PlaceType>(attr, x_data, + y_data, &attr); + } + } + } +} + // Benchmark all jit kernels including jitcode, mkl and refer. // To use this tool, run command: ./benchmark [options...] // Options: @@ -228,4 +248,7 @@ int main(int argc, char* argv[]) { BenchGRUKernel(); BenchGRUKernel(); BenchGRUKernel(); + + // seq pool function + BenchSeqPoolKernel(); } diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 8a540108302f77e1ca3bfe1db0013d76a22d5eb4..2b8c758a032fd7edff0d4b7e23bd8e685eb3ab15 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1) USE_JITKERNEL_GEN(kGRUHtPart1) USE_JITKERNEL_GEN(kGRUHtPart2) USE_JITKERNEL_GEN(kNCHW16CMulNC) +USE_JITKERNEL_GEN(kSeqPool) diff --git a/paddle/fluid/operators/jit/gen/seqpool.cc b/paddle/fluid/operators/jit/gen/seqpool.cc new file mode 100644 index 0000000000000000000000000000000000000000..530d24ee1fb7d9da84102641e1d4d2ab08ab1860 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/seqpool.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/fluid/operators/jit/gen/seqpool.h" +#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void SeqPoolJitCode::genCode() { + constexpr int block = YMM_FLOAT_BLOCK; + constexpr int max_num_regs = 8; + const int num_block = w_ / block; + const int num_groups = num_block / max_num_regs; + int rest_num_regs = num_block % max_num_regs; + mov(reg32_int_h, dword[param_attr]); + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + mov(reg_tmp, reinterpret_cast(exp_float_consts)); + vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]); + mov(reg_tmp, reinterpret_cast(fp_h_)); + fild(dword[param_attr]); + fstp(dword[reg_tmp]); + vmovss(xmm_t(0), ptr[reg_tmp]); + if (type_ == SeqPoolType::kSqrt) { + vsqrtps(xmm_t(0), xmm_t(0)); + } + vdivps(xmm_t(1), xmm_t(1), xmm_t(0)); + vmovss(ptr[reg_tmp], xmm_t(1)); + } + const int group_len = max_num_regs * block * sizeof(float); + for (int g = 0; g < num_groups; ++g) { + pool_height(g * group_len, block, max_num_regs); + } + if (rest_num_regs > 0) { + pool_height(num_groups * group_len, block, rest_num_regs); + } + // part of rest_w * height + const int rest = w_ % block; + pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs); + ret(); +} + +class SeqPoolCreator : public JitCodeCreator { + public: + bool UseMe(const seq_pool_attr_t& attr) const override { + return platform::MayIUse(platform::avx); + } + size_t CodeSize(const seq_pool_attr_t& attr) const override { + return 96 + + ((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) * + 4 /* load, mul and save */ + + 256) * + 8; + } + std::unique_ptr CreateJitCode( + const seq_pool_attr_t& attr) const override { + PADDLE_ENFORCE_GT(attr.w, 0); + PADDLE_ENFORCE_GT(attr.h, 0); + return make_unique(attr, CodeSize(attr)); + } +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator); diff --git a/paddle/fluid/operators/jit/gen/seqpool.h b/paddle/fluid/operators/jit/gen/seqpool.h new file mode 100644 index 0000000000000000000000000000000000000000..fcbbb3c84c562e2ba57110134bf07bb218b41edb --- /dev/null +++ b/paddle/fluid/operators/jit/gen/seqpool.h @@ -0,0 +1,214 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class SeqPoolJitCode : public JitCode { + public: + explicit SeqPoolJitCode(const seq_pool_attr_t& attr, + size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) { + if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg || + type_ == SeqPoolType::kSqrt)) { + LOG(FATAL) << "Only support sum pool yet "; + } + fp_h_[0] = 1.f; + this->genCode(); + } + + virtual const char* name() const { + std::string base = "SeqPoolJitCode"; + if (type_ == SeqPoolType::kSum) { + base += "_Sum"; + } else if (type_ == SeqPoolType::kAvg) { + base += "_Avg"; + } else if (type_ == SeqPoolType::kSqrt) { + base += "_Sqrt"; + } + base += ("_W" + std::to_string(w_)); + return base.c_str(); + } + void genCode() override; + + protected: + template + void pool_height(int w_offset, int block, int max_num_regs) { + int offset = w_offset; + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i), ptr[param_src + offset]); + offset += sizeof(float) * block; + } + cmp(reg32_int_h, 1); + Label l_next_h, l_h_done; + jle(l_h_done, T_NEAR); + mov(reg_h_i, 1); + mov(reg_tmp, param_src); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + mov(reg_ptr_src_i, reg_tmp); + for (int i = 0; i < max_num_regs; ++i) { + vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]); + // sum anyway + vaddps(JMM(i), JMM(i), JMM(i + max_num_regs)); + add(reg_ptr_src_i, sizeof(float) * block); + } + inc(reg_h_i); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h_i, reg32_int_h); + jl(l_next_h, T_NEAR); + } + L(l_h_done); + // save right now + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + mov(reg_tmp, reinterpret_cast(fp_h_)); + vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]); + } + offset = w_offset; + for (int i = 0; i < max_num_regs; ++i) { + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + vmulps(JMM(i), JMM(i), JMM(max_num_regs)); + } + vmovups(ptr[param_dst + offset], JMM(i)); + offset += sizeof(float) * block; + } + } + + void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) { + const int rest_used_num_regs = load_rest(rest, w_offset, 0); + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + cmp(reg32_int_h, 1); + Label l_next_h, l_h_done; + jle(l_h_done, T_NEAR); + mov(reg_h_i, 1); + mov(reg_tmp, param_src); + add(reg_tmp, w_ * sizeof(float) + w_offset); + L(l_next_h); + { + int reg_idx = 0; + mov(reg_ptr_src_i, reg_tmp); + if (has_block4) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 4); + reg_idx++; + } + if (has_block2) { + vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + add(reg_ptr_src_i, sizeof(float) * 2); + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); + reg_idx++; + } + PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, + "All heights should use same regs"); + for (int i = 0; i < reg_idx; ++i) { + vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); + } + inc(reg_h_i); + add(reg_tmp, w_ * sizeof(float)); + cmp(reg_h_i, reg32_int_h); + jl(l_next_h, T_NEAR); + } + L(l_h_done); + // save right now + if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) { + mov(reg_tmp, 
reinterpret_cast(fp_h_)); + vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]); + for (int i = 0; i < rest_used_num_regs; ++i) { + vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs)); + } + } + save_rest(rest, w_offset); + } + + // return the number of used regs, use start from reg 0 + int load_rest(int rest, int w_offset, const int num_shift_regs, + const int reg_start = 0) { + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + int reg_idx = reg_start; + if (has_block4) { + vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); + w_offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); + w_offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]); + reg_idx++; + } + return reg_idx; + } + + // use reg start from 0 + void save_rest(int rest, int w_offset, int reg_start = 0) { + const bool has_block4 = rest / 4 > 0; + const bool has_block2 = (rest % 4) / 2 > 0; + const bool has_block1 = (rest % 2) == 1; + int reg_idx = reg_start; + if (has_block4) { + vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx)); + w_offset += sizeof(float) * 4; + reg_idx++; + } + if (has_block2) { + vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx)); + w_offset += sizeof(float) * 2; + reg_idx++; + } + if (has_block1) { + vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx)); + } + } + + private: + float ALIGN32_BEG fp_h_[1] ALIGN32_END; + int w_; + SeqPoolType type_; + reg64_t param_src{abi_param1}; + reg64_t param_dst{abi_param2}; + reg64_t param_attr{abi_param3}; + reg64_t reg_tmp{rax}; + + reg32_t reg32_int_h{r8d}; + reg32_t reg32_fp_h{r9d}; + + reg64_t reg_h_i{r10}; + reg64_t reg_ptr_src_i{r11}; +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index d00584baa081c21762774aef4cbbc714d49cd012..7d02590f2e5d82b5105132d7af716f14c661d067 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -26,6 +26,7 @@ namespace jit { const char* to_string(KernelType kt) { switch (kt) { + ONE_CASE(kNone); ONE_CASE(kVMul); ONE_CASE(kVAdd); ONE_CASE(kVAddRelu); @@ -45,12 +46,26 @@ const char* to_string(KernelType kt) { ONE_CASE(kCRFDecoding); ONE_CASE(kLayerNorm); ONE_CASE(kNCHW16CMulNC); + ONE_CASE(kSeqPool); default: PADDLE_THROW("Not support type: %d, or forget to add it.", kt); return "NOT JITKernel"; } return nullptr; } + +const char* to_string(SeqPoolType tp) { + switch (tp) { + ONE_CASE(kNonePoolType); + ONE_CASE(kSum); + ONE_CASE(kAvg); + ONE_CASE(kSqrt); + default: + PADDLE_THROW("Not support type: %d, or forget to add it.", tp); + return "NOT PoolType"; + } + return nullptr; +} #undef ONE_CASE KernelType to_kerneltype(const std::string& act) { diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 412df86aa1cd94871989aef25adef803f673812b..fbf34fc4b3db49596b6be0360c00e77c12fab9b8 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -119,6 +119,7 @@ typename KernelTuples::func_type Get( } const char* to_string(KernelType kt); +const char* to_string(SeqPoolType kt); KernelType to_kerneltype(const std::string& act); @@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) { << "],act_cand[" << 
to_string(attr.act_cand) << "]"; return os; } +inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) { + os << "height_size[" << attr.h << "],width_size[" << attr.w << "],pool_type[" + << to_string(attr.type) << "]"; + return os; +} } // namespace jit } // namespace operators diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index b4a2d5d47301a2fd82bf27ddfaaa31ef23e431c2..2a7697a6f253dcc2b8143d9f14a80a1cfd45996d 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -41,8 +41,16 @@ typedef enum { kCRFDecoding, kLayerNorm, kNCHW16CMulNC, + kSeqPool, } KernelType; +typedef enum { + kNonePoolType = 0, + kSum = 1, + kAvg, + kSqrt, +} SeqPoolType; + template struct XYZNTuples { typedef T data_type; @@ -112,6 +120,21 @@ struct GRUTuples { typedef void (*func_type)(gru_t*, const gru_attr_t*); }; +typedef struct seq_pool_attr_s { + int h, w; // h should always be the first one + SeqPoolType type; + seq_pool_attr_s() = default; + explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1) + : h(height), w(width), type(pool_type) {} +} seq_pool_attr_t; + +template +struct SeqPoolTuples { + typedef T data_type; + typedef seq_pool_attr_t attr_type; + typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); +}; + template struct CRFDecodingTuples { typedef T data_type; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index 4e6a19f04fd425b920aeea49b63001941d800a73..61de38688664f83775c0c4e5aa6f7e06c3602ddb 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -42,6 +42,13 @@ size_t JitCodeKey(const gru_attr_t& attr) { (static_cast(attr.act_cand) << act_type_shift); } +template <> +size_t JitCodeKey(const seq_pool_attr_t& attr) { + size_t key = attr.w; + constexpr int pool_type_shift = 3; + return (key << pool_type_shift) + static_cast(attr.type); +} + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index 863cc720d68ce3dcfe045aa11c559a06a50909f3..f5ed2f0572176e42b774259c2b8fe9713d989417 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl) USE_JITKERNEL_MORE(kVExp, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVTanh, mkl) +USE_JITKERNEL_MORE(kSeqPool, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index a5b088d4812b8a54e3b4fb1cb83d9e8bc7501994..5a499ac2c02aa70d2824f0d3be618e083ba10334 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -72,6 +72,26 @@ void VExp(const double* x, double* y, int n) { platform::dynload::vdExp(n, x, y); } +template <> +void VCopy(const float* x, float* y, int n) { + platform::dynload::cblas_scopy(n, x, 1, y, 1); +} + +template <> +void VCopy(const double* x, double* y, int n) { + platform::dynload::cblas_dcopy(n, x, 1, y, 1); +} + +template <> +void VAXPY(float a, const float* x, float* y, int n) { + platform::dynload::cblas_saxpy(n, a, x, 1, y, 1); +} + +template <> +void VAXPY(double a, const double* x, double* y, int n) { + platform::dynload::cblas_daxpy(n, a, x, 1, y, 1); +} + // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 template <> bool 
VMulKernel::UseMe(const int& d) const { @@ -103,6 +123,16 @@ bool VTanhKernel::UseMe(const int& d) const { return d > 7; } +template <> +bool SeqPoolKernel::UseMe(const seq_pool_attr_t& attr) const { + return true; +} + +template <> +bool SeqPoolKernel::UseMe(const seq_pool_attr_t& attr) const { + return true; +} + #define AWALYS_USE_ME_WITH_DOUBLE(func) \ template <> \ bool func##Kernel::UseMe(const int& d) const { \ @@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVTanh, VTanh); +REGISTER_MKL_KERNEL(kSeqPool, SeqPool); #undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index ee1031c028ff72181f504004b7cbeb9f7ee578f1..0a3816db24ccd0820cb259b40044e1f5b66665f7 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/fluid/operators/jit/kernel_base.h" @@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n); template void VExp(const T* x, T* y, int n); +template +void VCopy(const T* x, T* y, int n); + +template +void VAXPY(T a, const T* x, T* y, int n); + template void VSigmoid(const T* x, T* y, int n) { const T min = SIGMOID_THRESHOLD_MIN; @@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) { } } +template +void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { + VCopy(x, y, attr->w); + for (int h = 1; h != attr->h; ++h) { + VAXPY(static_cast(1), x + h * attr->w, y, attr->w); + } + if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) { + T scalar = static_cast(1); + if (attr->type == SeqPoolType::kAvg) { + scalar = scalar / static_cast(attr->h); + } else { + scalar = scalar / std::sqrt(static_cast(attr->h)); + } + VScal(&scalar, y, y, attr->w); + } +} + #define DECLARE_MKL_KERNEL(name, tuples) \ template \ class name##Kernel : public KernelMore> { \ @@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples); +DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); + #undef DECLARE_MKL_KERNEL } // namespace mkl diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 07497b732050a7299e224531db37eb56e60ef605..0f626bb3bfd2851e3fb6ad8265169f9bb9860851 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2) USE_JITKERNEL_REFER(kCRFDecoding) USE_JITKERNEL_REFER(kLayerNorm) USE_JITKERNEL_REFER(kNCHW16CMulNC) +USE_JITKERNEL_REFER(kSeqPool) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index d196266326b4ee668f647fa51032f6344d26e5c6..85381daa47484a4053326f04e12d583543a423e0 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); +REGISTER_REFER_KERNEL(kSeqPool, SeqPool); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 0fd1b89dfdba9f4655f649fa6d32604188c78da3..b4e9c8dd107ee844544165b1719d38754ae976bc 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -332,6 +332,28 @@ void NCHW16CMulNC(const T* x, 
const T* y, T* z, int height, int width) { } } +template +void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { + for (int w = 0; w < attr->w; ++w) { + const T* src = x + w; + T* dst = y + w; + *dst = static_cast(0); + for (int h = 0; h < attr->h; ++h) { + *dst = *dst + *src; + src += attr->w; + } + } + if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) { + T scalar = static_cast(1); + if (attr->type == SeqPoolType::kAvg) { + scalar = scalar / static_cast(attr->h); + } else { + scalar = scalar / std::sqrt(static_cast(attr->h)); + } + VScal(&scalar, y, y, attr->w); + } +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -370,6 +392,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); +DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index a73e2a60aeb0c1594b5072b2bffbd11cccfcdc7d..30291bfef3bc96fe2e687e5be6d782eee89496aa 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -211,6 +211,24 @@ struct TestFuncWithRefer, std::vector, std::vector, } }; +template +struct TestFuncWithRefer, std::vector, + std::vector> { + void operator()(const typename jit::SeqPoolTuples::func_type tgt, + const std::vector& x, const std::vector& yref, + const typename jit::SeqPoolTuples::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(x.size() % yref.size(), 0); + int w = yref.size(); + std::vector y(w); + const T* x_data = x.data(); + const T* yref_data = yref.data(); + T* y_data = y.data(); + tgt(x_data, y_data, &attr); + ExpectEQ(y_data, yref_data, w); + } +}; + template void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { @@ -415,6 +433,31 @@ void TestGRUKernel() { } } +template +void TestSeqPoolKernel() { + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + std::vector pool_types = { + jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; + for (auto type : pool_types) { + for (int w : TestSizes()) { + jit::seq_pool_attr_t attr(w, type); + for (int h : TestSizes()) { + attr.h = h; + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + std::vector x(h * w), yref(w); + RandomVec(h * w, x.data(), -2.f, 2.f); + const T* x_data = x.data(); + T* yref_data = yref.data(); + ref(x_data, yref_data, &attr); + VLOG(10) << attr; + TestAllImpls, PlaceType, std::vector, + std::vector>(attr, x, yref, attr); + } + } + } +} + template void TestNCHW16CMulNCKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); @@ -569,6 +612,12 @@ TEST(JITKernel, kGRUHtPart2) { TestGRUKernel(); } +TEST(JITKernel, kSeqPool) { + namespace jit = paddle::operators::jit; + TestSeqPoolKernel(); + TestSeqPoolKernel(); +} + TEST(JITKernel, kNCHW16CMulNC) { namespace jit = paddle::operators::jit; TestNCHW16CMulNCKernel { cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Btype, int ldb, const float *beta, void *C, cudaDataType_t Ctype, int ldc) { - // Because the gcc 4.8 doesn't expand template parameter pack that - // appears in a lambda-expression, I can not use template parameter pack - // here. - auto cublas_call = [&]() { +// Because the gcc 4.8 doesn't expand template parameter pack that +// appears in a lambda-expression, I can not use template parameter pack +// here. 
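For readers following the cuBLAS changes below, a short, hedged sketch of the new handle-passing style that replaces the old `cublas_call` lambda plus `CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH)` pattern; the function name and arguments are assumptions, not part of the patch. The device context now owns the lock around its cublas handle, and the tensor-core variant (`TensorCoreCublasCallIfAvailable`) only switches the math mode when tensor cores are actually usable:

```cpp
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/enforce.h"

// Illustrative caller: multiply two n x n column-major float matrices that
// already live on the device owned by `ctx`.
void Sgemm(const paddle::platform::CUDADeviceContext &ctx, int n,
           const float *d_a, const float *d_b, float *d_c) {
  const float alpha = 1.0f;
  const float beta = 0.0f;
  ctx.CublasCall([&](cublasHandle_t handle) {
    PADDLE_ENFORCE(paddle::platform::dynload::cublasSgemm(
        handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &alpha, d_a, n, d_b, n,
        &beta, d_c, n));
  });
}
```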
#if CUDA_VERSION >= 8000 - VLOG(5) << "use_tensor_op_math: " - << (platform::TensorCoreAvailable() ? "True" : "False"); + VLOG(5) << "use_tensor_op_math: " + << (dev_ctx->tensor_core_available() ? "True" : "False"); + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc)); + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc)); + }); #else - PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. - dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); -#else - cublas_call(); + PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); #endif } }; @@ -170,32 +162,24 @@ struct CUBlas { cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Ctype, int ldc, cudaDataType_t computeType) { - auto cublas_call = [&]() { #if CUDA_VERSION >= 8000 - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; #if CUDA_VERSION >= 9000 - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); + bool use_tensor_op_math = dev_ctx->tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); #endif // CUDA_VERSION >= 9000 + dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasGemmEx( - dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, - lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); + handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, + beta, C, Ctype, ldc, computeType, algo)); + }); #else - PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); -#endif - }; - -#if CUDA_VERSION >= 9000 - // NOTES: To use Tensor Core, we should change the cublas config, - // but the cublas may be hold by multi-thread. 
- dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); -#else - cublas_call(); + PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); #endif } }; @@ -223,9 +207,10 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, CUDA_R_32F, N); } else { #endif // CUDA_VERSION >= 8000 - - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, N); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, N); + }); #if CUDA_VERSION >= 8000 } @@ -266,9 +251,12 @@ inline void Blas::GEMM( CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &h_alpha, h_B, ldb, h_A, lda, - &h_beta, h_C, N); + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, + &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, + N); + }); #endif // CUDA_VERSION >= 8000 } @@ -292,8 +280,10 @@ void Blas::GEMM(bool transA, bool transB, int M, } else { #endif // CUDA_VERSION >= 8000 - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, - &alpha, B, ldb, A, lda, &beta, C, ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, + lda, &beta, C, ldc); + }); #if CUDA_VERSION >= 8000 } @@ -311,16 +301,19 @@ inline void Blas::GEMM( cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, A, lda, &beta, C, - ldc); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, A, lda, &beta, C, ldc); + }); } template <> template void Blas::AXPY(int n, T alpha, const T *x, T *y) const { - CUBlas::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::AXPY(handle, n, &alpha, x, 1, y, 1); + }); } template <> @@ -330,8 +323,9 @@ void Blas::GEMV(bool trans_a, int M, int N, T beta, T *C) const { cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; - CUBlas::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1, - &beta, C, 1); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1); + }); } template <> @@ -353,28 +347,28 @@ void Blas::BatchedGEMM( #if CUDA_VERSION >= 9010 if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { - auto cublas_call = [&]() { - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; - bool use_tensor_op_math = platform::TensorCoreAvailable(); - if (use_tensor_op_math) { - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } - VLOG(5) << "use_tensor_op_math: " - << (use_tensor_op_math ? "True" : "False"); - + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? 
"True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( - context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, - CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, - CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); - }; - auto &dev_ctx = const_cast(context_); - dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH); + handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb, + strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc, + strideC, batchCount, CUDA_R_32F, algo)); + }); } else { #endif // CUDA_VERSION >= 9010 - CUBlas::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA, - N, M, K, &alpha, B, ldb, strideB, A, lda, - strideA, &beta, C, ldc, strideC, batchCount); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, strideB, A, lda, strideA, &beta, C, + ldc, strideC, batchCount); + }); #if CUDA_VERSION >= 9010 } diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc index 6d491dbf1ed162ef07fda4c07e95cc57108486fd..2a47502614b9cd3df4583992669ab4bf78228181 100644 --- a/paddle/fluid/operators/math/sequence_pooling.cc +++ b/paddle/fluid/operators/math/sequence_pooling.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/sequence_pooling.h" @@ -239,15 +240,33 @@ class SequencePoolFunctor { last_pool(context, input, output); return; } - if (pooltype == "FIRST") { math::FirstSeqPoolFunctor first_pool; first_pool(context, input, output); return; } + auto lod = input.lod()[0]; + if (pooltype == "SUM") { + auto place = context.GetPlace(); + PADDLE_ENFORCE(platform::is_cpu_place(place)); + const T* src = input.data(); + T* dst = output->mutable_data(place); + jit::seq_pool_attr_t attr( + static_cast(input.numel() / input.dims()[0]), + jit::SeqPoolType::kSum); + auto seqpool = + jit::Get, platform::CPUPlace>( + attr); + for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { + attr.h = static_cast(lod[i + 1] - lod[i]); + seqpool(src, dst, &attr); + dst += attr.w; + src += attr.h * attr.w; + } + return; + } auto& place = *context.eigen_device(); - auto blas = math::GetBlas(context); for (int i = 0; i < static_cast(lod.size()) - 1; ++i) { Tensor in_t = input.Slice(static_cast(lod[i]), static_cast(lod[i + 1])); @@ -258,15 +277,6 @@ class SequencePoolFunctor { auto out_e = EigenVector::Flatten(out_t); if (pooltype == "AVERAGE") { out_e.device(place) = in_e.mean(Eigen::array({{0}})); - } else if (pooltype == "SUM") { - if (h > 0) { - const T* in_data = in_t.data(); - T* out_data = out_t.mutable_data(context.GetPlace()); - blas.VCOPY(w, in_data, out_data); - for (int64_t r = 1; r != h; ++r) { - blas.AXPY(w, 1., in_data + r * w, out_data); - } - } } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); diff --git a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h index 6610380fcf432d0019f7e844fa9304e151b20efd..0c0d25d0cd1ae536618057ce80388b8eeb81c68a 100644 --- a/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h +++ b/paddle/fluid/operators/ngraph/ops/binary_unnary_op.h @@ -12,7 +12,6 @@ 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -48,4 +47,3 @@ static void BuildUnaryNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h index 15fbd58b02d2b13a8f5401f7cbe291da35748e83..8f5092963c8b79501ea68c1f521c4678977635ea 100644 --- a/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h +++ b/paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -58,4 +57,3 @@ std::shared_ptr ElementwiseScalar( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h index 5eff69e7b165fa19c775926914b7b3e8fcb043e5..406a4314f89810df192280cc97de245553d5520f 100644 --- a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h +++ b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -58,4 +57,3 @@ void BuildFillConstantNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/mean_op.h b/paddle/fluid/operators/ngraph/ops/mean_op.h index 7fcf8f09cd346db8cf6706014e0d4573ced7a86c..4c44bc4c112f401c2707f7babd49a33f238a768f 100644 --- a/paddle/fluid/operators/ngraph/ops/mean_op.h +++ b/paddle/fluid/operators/ngraph/ops/mean_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -65,4 +64,3 @@ void BuildMeanGradNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/mul_op.h b/paddle/fluid/operators/ngraph/ops/mul_op.h index 9e12e5d7c3da04706907c7ae63ce8046ce667f25..4a6cbebe245f891c6c33b2116330a41d89d50e25 100644 --- a/paddle/fluid/operators/ngraph/ops/mul_op.h +++ b/paddle/fluid/operators/ngraph/ops/mul_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -131,4 +130,3 @@ static void BuildMulGradNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/scale_op.h b/paddle/fluid/operators/ngraph/ops/scale_op.h index 24ab0702aa50861b34fe1af7ccaf37d4e1dffc41..91a57d0be606373e985a30b7ac9c73648062d8e4 100644 --- a/paddle/fluid/operators/ngraph/ops/scale_op.h +++ b/paddle/fluid/operators/ngraph/ops/scale_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -38,4 +37,3 @@ void BuildScaleNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/operators/ngraph/ops/top_k_op.h b/paddle/fluid/operators/ngraph/ops/top_k_op.h index 2b7254497c0e1aab2e653e69e6461f262b929703..ea66953a125860ab1ce8309819b6c433ff32eaaa 100644 --- a/paddle/fluid/operators/ngraph/ops/top_k_op.h +++ b/paddle/fluid/operators/ngraph/ops/top_k_op.h @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_WITH_NGRAPH #pragma once #include @@ -48,4 +47,3 @@ void BuildTopKNode( } // namespace ngraphs } // namespace operators } // namespace paddle -#endif diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..122de72e15d587cf33b5d9856ac8b1243f666881 --- /dev/null +++ b/paddle/fluid/platform/cuda_helper.h @@ -0,0 +1,58 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <mutex>  // NOLINT
+
+#include "paddle/fluid/platform/dynload/cublas.h"
+#include "paddle/fluid/platform/macros.h"
+
+#if CUDA_VERSION < 9000
+enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
+#endif
+
+namespace paddle {
+namespace platform {
+
+class CublasHandleHolder {
+ public:
+  CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
+    PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
+    PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
+#if CUDA_VERSION >= 9000
+    if (math_type == CUBLAS_TENSOR_OP_MATH) {
+      PADDLE_ENFORCE(
+          dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
+    }
+#endif
+  }
+
+  ~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }
+
+  template <typename Callback>
+  inline void Call(Callback &&callback) const {
+    std::lock_guard<std::mutex> guard(mtx_);
+    callback(handle_);
+  }
+
+ private:
+  DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);
+
+  cublasHandle_t handle_;
+  mutable std::mutex mtx_;
+};
+
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 6f38dbb7a20dae4c4ea1e448c8572d98800b0213..09f3d3de54e4388f7090621a0fead96b3043d918 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
-  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
-  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
+  cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
+
+  if (TensorCoreAvailable()) {
+#if CUDA_VERSION >= 9000
+    cublas_tensor_core_handle_.reset(
+        new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
+#endif
+  }
+
   if (dynload::HasCUDNN()) {
     cudnn_holder_.reset(new CudnnHolder(&stream_, place));
   }
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
   SetDeviceId(place_.device);
   Wait();
   WaitStreamCallback();
-  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
+  cublas_handle_.reset();
+  cublas_tensor_core_handle_.reset();
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
 
-cublasHandle_t CUDADeviceContext::cublas_handle() const {
-  return cublas_handle_;
+bool CUDADeviceContext::tensor_core_available() const {
+  return cublas_tensor_core_handle_ != nullptr;
 }
 
 cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 7e875801893f3b73f8efaf33af690f8c855beee4..c81d17380cf894631d06588c007c2e11ce5c7836 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/temporary_allocator.h"
 #ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cuda_helper.h"
 #include "paddle/fluid/platform/dynload/cublas.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/gpu_info.h"
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
   std::unique_ptr<std::lock_guard<std::mutex>> guard_;
 };
 
-#if CUDA_VERSION >= 9000
-class ScopedCublasMathMode {
- public:
-  ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
-      : handle_(handle) {
-    need_reset = false;
-    PADDLE_ENFORCE(
-        platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
-        "Failed to get old cublas math mode");
-    if (old_math_mode_ != new_math_mode) {
-      PADDLE_ENFORCE(
-          platform::dynload::cublasSetMathMode(handle_, new_math_mode),
-          "Failed to set old cublas math mode");
-      need_reset = true;
-    }
-  }
-
-  ~ScopedCublasMathMode() {
-    if (need_reset) {
-      PADDLE_ENFORCE(
-          platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
-          "Failed to set old cublas math mode");
-    }
-  }
-
- private:
-  cublasHandle_t handle_;
-  cublasMath_t old_math_mode_;
-  bool need_reset;
-};
-
-#endif
-
 class CUDADeviceContext : public DeviceContext {
  public:
   explicit CUDADeviceContext(CUDAPlace place);
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
 
-  /*! \brief Return cublas handle in the device context. */
-  cublasHandle_t cublas_handle() const;
+  /*! \brief Call cublas function safely. */
+  template <typename Callback>
+  inline void CublasCall(Callback&& callback) const {
+    cublas_handle_->Call(std::forward<Callback>(callback));
+  }
+
+  /*! \brief Check whether tensor core is supported */
+  bool tensor_core_available() const;
+
+  /*! \brief Call cublas function with Tensor Core safely. If
+      Tensor Core is not available, use DEFAULT_MATH instead. */
+  template <typename Callback>
+  inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
+    if (cublas_tensor_core_handle_) {
+      cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
+    } else {
+      cublas_handle_->Call(std::forward<Callback>(callback));
+    }
+  }
 
   /*! \brief Return cudnn handle in the device context. */
   cudnnHandle_t cudnn_handle() const;
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
 
   template <typename Callback>
   void RecordEvent(cudaEvent_t ev, Callback callback) {
-    std::lock_guard<std::mutex> guard(mtx_);
     callback();
     PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
   }
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
 
   void WaitStreamCallback() const { callback_manager_->Wait(); }
 
-#if CUDA_VERSION >= 9000
-  /*! \brief CublasCall may need to change cublas's config,
-   *  but the cublas may be hold by multi-thread, so we should
-   *  add lock here.
-   */
-  template <typename Callback>
-  void CublasCall(Callback callback, cublasMath_t new_math) {
-    std::lock_guard<std::mutex> guard(cublas_mtx_);
-    ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
-    callback();
-  }
-#endif
-
  private:
   CUDAPlace place_;
 
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
   std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
   std::unique_ptr<CudnnHolder> cudnn_holder_;
   cudaStream_t stream_;
-  cublasHandle_t cublas_handle_;
+
+  std::unique_ptr<CublasHandleHolder> cublas_handle_;
+  std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
 
   int compute_capability_;
   int runtime_version_;
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
   int multi_process_;
   int max_threads_per_mp_;
 
-  mutable std::mutex mtx_;
-
   // StreamCallbackManager is thread-safe
   std::unique_ptr<StreamCallbackManager> callback_manager_;
 
-  mutable std::mutex cublas_mtx_;
+  DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
 };
 
 template <>
diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu
index 171d2979a0218ad5e22112190a59866b3e0b617f..5b3aa98efb46b51d6c3edb6d2cbd4200bd0a35c6 100644
--- a/paddle/fluid/platform/device_context_test.cu
+++ b/paddle/fluid/platform/device_context_test.cu
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
     ASSERT_NE(nullptr, gpu_device);
     cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
     ASSERT_NE(nullptr, cudnn_handle);
-    cublasHandle_t cublas_handle = device_context->cublas_handle();
-    ASSERT_NE(nullptr, cublas_handle);
-
     ASSERT_NE(nullptr, device_context->stream());
     delete device_context;
   }
 }
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index f9f3807b1567eaf0be20b522154552a8b157583f..2c17716500ababfab3216a5ec47fecca30065ff1 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -155,7 +155,7 @@ def __bootstrap__():
         'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
         'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
         'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
-        'cudnn_exhaustive_search_times', 'sync_nccl_allreduce'
+        'sync_nccl_allreduce'
     ]
     core.init_gflags([sys.argv[0]] +
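
Note: the sketch below is illustrative and not part of the patch. It shows how a caller is expected to use the reworked CUDADeviceContext cublas interface introduced above (CublasCall and TensorCoreCublasCallIfAvailable); the wrapper function, its name, and the enclosing namespace are hypothetical, while the context methods, PADDLE_ENFORCE, and platform::dynload::cublasSgemm follow the patched code.

// Illustrative usage sketch (hypothetical helper, not part of the patch).
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/enforce.h"

namespace example {

// Runs an SGEMM through the context's default-math cublas handle. The lambda
// only sees the handle while CublasHandleHolder's internal mutex is held, so
// concurrent callers sharing one CUDADeviceContext are serialized without the
// old context-wide cublas_mtx_.
void SgemmOnContext(const paddle::platform::CUDADeviceContext &ctx,
                    cublasOperation_t transa, cublasOperation_t transb, int m,
                    int n, int k, const float *alpha, const float *A, int lda,
                    const float *B, int ldb, const float *beta, float *C,
                    int ldc) {
  ctx.CublasCall([&](cublasHandle_t handle) {
    PADDLE_ENFORCE(paddle::platform::dynload::cublasSgemm(
        handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc));
  });
}

}  // namespace example

Kernels that can exploit Tensor Cores call ctx.TensorCoreCublasCallIfAvailable(...) with the same lambda shape; when the device or CUDA version lacks CUBLAS_TENSOR_OP_MATH support, that method transparently falls back to the default-math handle, so callers no longer toggle the math mode themselves.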