提交 d0e3b240 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-dist-sparse-decay

test=develop
...@@ -19,3 +19,10 @@ find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALL ...@@ -19,3 +19,10 @@ find_package_handle_standard_args(jemalloc DEFAULT_MSG JEMALLOC_LIBRARIES JEMALL
mark_as_advanced( mark_as_advanced(
JEMALLOC_LIBRARIES JEMALLOC_LIBRARIES
JEMALLOC_INCLUDE_DIR) JEMALLOC_INCLUDE_DIR)
if (JEMALLOC_FOUND)
add_library(jemalloc::jemalloc UNKNOWN IMPORTED)
set_target_properties(jemalloc::jemalloc PROPERTIES
IMPORTED_LOCATION ${JEMALLOC_LIBRARIES}
INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}")
endif()
...@@ -2,9 +2,11 @@ if(NOT WITH_GPU) ...@@ -2,9 +2,11 @@ if(NOT WITH_GPU)
return() return()
endif() endif()
set(paddle_known_gpu_archs "30 35 50 52 60 61 70 75") set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs7 "30 35 50 52") set(paddle_known_gpu_archs7 "30 35 50 52")
set(paddle_known_gpu_archs8 "30 35 50 52 60 61") set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
set(paddle_known_gpu_archs10 "30 35 50 52 60 61 70 75")
###################################################################################### ######################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled) # A function for automatic detection of GPUs installed (if autodetection is enabled)
...@@ -155,6 +157,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x ...@@ -155,6 +157,16 @@ elseif (${CUDA_VERSION} LESS 9.0) # CUDA 8.x
# warning for now. # warning for now.
list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets") list(APPEND CUDA_NVCC_FLAGS "-Wno-deprecated-gpu-targets")
add_definitions("-DPADDLE_CUDA_BINVER=\"80\"") add_definitions("-DPADDLE_CUDA_BINVER=\"80\"")
elseif (${CUDA_VERSION} LESS 10.0) # CUDA 9.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"90\"")
elseif (${CUDA_VERSION} LESS 11.0) # CUDA 10.x
set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
list(APPEND CUDA_NVCC_FLAGS "-D_MWAITXINTRIN_H_INCLUDED")
list(APPEND CUDA_NVCC_FLAGS "-D__STRICT_ANSI__")
add_definitions("-DPADDLE_CUDA_BINVER=\"100\"")
endif() endif()
include_directories(${CUDA_INCLUDE_DIRS}) include_directories(${CUDA_INCLUDE_DIRS})
......
...@@ -23,11 +23,8 @@ set(BOOST_PROJECT "extern_boost") ...@@ -23,11 +23,8 @@ set(BOOST_PROJECT "extern_boost")
# checked that the devtools package of CentOS 6 installs boost 1.41.0. # checked that the devtools package of CentOS 6 installs boost 1.41.0.
# So we use 1.41.0 here. # So we use 1.41.0 here.
set(BOOST_VER "1.41.0") set(BOOST_VER "1.41.0")
if((NOT DEFINED BOOST_TAR) OR (NOT DEFINED BOOST_URL)) set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
message(STATUS "use pre defined download url") set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
set(BOOST_TAR "boost_1_41_0" CACHE STRING "" FORCE)
set(BOOST_URL "http://paddlepaddledeps.cdn.bcebos.com/${BOOST_TAR}.tar.gz" CACHE STRING "" FORCE)
endif()
MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}") MESSAGE(STATUS "BOOST_TAR: ${BOOST_TAR}, BOOST_URL: ${BOOST_URL}")
......
...@@ -55,7 +55,7 @@ ExternalProject_Add( ...@@ -55,7 +55,7 @@ ExternalProject_Add(
${MKLDNN_PROJECT} ${MKLDNN_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS} DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" GIT_REPOSITORY "https://github.com/intel/mkl-dnn.git"
GIT_TAG "830a10059a018cd2634d94195140cf2d8790a75a" GIT_TAG "830a10059a018cd2634d94195140cf2d8790a75a"
PREFIX ${MKLDNN_SOURCES_DIR} PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
......
...@@ -16,6 +16,12 @@ IF(NOT ${WITH_MKLML}) ...@@ -16,6 +16,12 @@ IF(NOT ${WITH_MKLML})
return() return()
ENDIF(NOT ${WITH_MKLML}) ENDIF(NOT ${WITH_MKLML})
IF(APPLE)
MESSAGE(WARNING "Mac is not supported with MKLML in Paddle yet. Force WITH_MKLML=OFF.")
SET(WITH_MKLML OFF CACHE STRING "Disable MKLML package in MacOS" FORCE)
return()
ENDIF()
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(MKLML_DST_DIR "mklml") SET(MKLML_DST_DIR "mklml")
SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install") SET(MKLML_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
...@@ -23,32 +29,24 @@ SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR}) ...@@ -23,32 +29,24 @@ SET(MKLML_INSTALL_DIR ${MKLML_INSTALL_ROOT}/${MKLML_DST_DIR})
SET(MKLML_ROOT ${MKLML_INSTALL_DIR}) SET(MKLML_ROOT ${MKLML_INSTALL_DIR})
SET(MKLML_INC_DIR ${MKLML_ROOT}/include) SET(MKLML_INC_DIR ${MKLML_ROOT}/include)
SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib) SET(MKLML_LIB_DIR ${MKLML_ROOT}/lib)
if(WIN32) SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
SET(TIME_VERSION "2019.0.1.20181227")
IF(WIN32)
SET(MKLML_VER "mklml_win_${TIME_VERSION}" CACHE STRING "" FORCE)
SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib) SET(MKLML_LIB ${MKLML_LIB_DIR}/mklml.lib)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.lib)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/mklml.dll)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5md.dll)
else() ELSE()
SET(MKLML_VER "mklml_lnx_${TIME_VERSION}" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) SET(MKLML_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/libmklml_intel.so) SET(MKLML_SHARED_LIB ${MKLML_LIB_DIR}/libmklml_intel.so)
SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so) SET(MKLML_SHARED_IOMP_LIB ${MKLML_LIB_DIR}/libiomp5.so)
endif() ENDIF()
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${MKLML_ROOT}/lib")
IF((NOT DEFINED MKLML_VER) OR (NOT DEFINED MKLML_URL))
MESSAGE(STATUS "use pre defined download url")
if(WIN32)
SET(MKLML_VER "mklml_win_2019.0.1.20180928" CACHE STRING "" FORCE)
SET(MKLML_URL "https://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.zip" CACHE STRING "" FORCE)
elseif(APPLE)
SET(MKLML_VER "mklml_mac_2019.0.1.20180928" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
else()
SET(MKLML_VER "mklml_lnx_2019.0.1.20180928" CACHE STRING "" FORCE)
SET(MKLML_URL "http://paddlepaddledeps.cdn.bcebos.com/${MKLML_VER}.tgz" CACHE STRING "" FORCE)
ENDIF()
endif()
SET(MKLML_PROJECT "extern_mklml") SET(MKLML_PROJECT "extern_mklml")
MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}") MESSAGE(STATUS "MKLML_VER: ${MKLML_VER}, MKLML_URL: ${MKLML_URL}")
......
...@@ -117,7 +117,7 @@ function(common_link TARGET_NAME) ...@@ -117,7 +117,7 @@ function(common_link TARGET_NAME)
endif() endif()
if (WITH_JEMALLOC) if (WITH_JEMALLOC)
target_link_libraries(${TARGET_NAME} ${JEMALLOC_LIBRARIES}) target_link_libraries(${TARGET_NAME} jemalloc::jemalloc)
endif() endif()
endfunction() endfunction()
......
...@@ -94,4 +94,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS ...@@ -94,4 +94,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass fuse_elewise_add_act_pass multi_batch_merge_pass
memory_optimize_pass) memory_optimize_pass lock_free_optimize_pass)
...@@ -232,3 +232,4 @@ USE_PASS(analysis_var_pass); ...@@ -232,3 +232,4 @@ USE_PASS(analysis_var_pass);
USE_PASS(sequential_execution_pass); USE_PASS(sequential_execution_pass);
USE_PASS(all_reduce_deps_pass); USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(lock_free_optimize_pass);
...@@ -31,6 +31,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) ...@@ -31,6 +31,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base) pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base) pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference) pass_library(infer_clean_graph_pass inference)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/lock_free_optimize_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
const char kSumGradOpName[] = "sum";
// TODO(minqiyang): only support sgd at current time, please add
// other optimizers later.
const char kOptimizerType[] = "sgd";
std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
// We could collect all weights' name from SGD, where
// W1 <- SGD(W0, Grad0)
std::unordered_set<std::string> weight_var_set;
for (auto* node : graph->Nodes()) {
if (IsOpNamed(node, kOptimizerType)) {
auto& param_out_vars = node->Op()->Output("ParamOut");
PADDLE_ENFORCE(param_out_vars.size() == 1u);
weight_var_set.insert(param_out_vars[0]);
}
}
// find all grad's merge op via weight name, where
// Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
std::unordered_set<ir::Node*> grad_sum_op_set;
for (ir::Node* node : graph->Nodes()) {
if (IsOpNamed(node, kSumGradOpName)) {
for (ir::Node* output : node->outputs) {
// strip the last grad suffix @GRAD
std::string var_name = output->Name();
const std::string suffix(kGradVarSuffix);
if (var_name != suffix && var_name.size() > suffix.size() &&
var_name.substr(var_name.size() - suffix.size()) == suffix) {
// if so then strip them off
var_name = var_name.substr(0, var_name.size() - suffix.size());
if (weight_var_set.find(var_name) != weight_var_set.end()) {
grad_sum_op_set.insert(node);
break;
}
}
}
}
}
// get the forward op and backward op pairs, where
// out <- forward(X, W)
// Grad1 <- backward(out, X')
// Grad0 <- SUM(Grad1, Grad2, Grad3 ...)
// W0 <- SGD(W1, Grad0)
for (ir::Node* node : grad_sum_op_set) {
for (ir::Node* merged_grad_var : node->outputs) {
// find the optimizers connected with sum op
if (IsVarNameEndsWith(merged_grad_var, kGradVarSuffix) &&
merged_grad_var->outputs.size() == 1u) {
ir::Node* opt_node = merged_grad_var->outputs[0];
VLOG(3) << "Found opt node " << opt_node->Name();
// find the backward op connected with sum op
for (ir::Node* unmerged_grad_var : node->inputs) {
if (IsVarNameContains(unmerged_grad_var, kGradVarSuffix) &&
unmerged_grad_var->inputs.size() == 1u) {
ir::Node* backward_op = unmerged_grad_var->inputs[0];
VLOG(3) << "Found backward_op " << backward_op->Name();
// find the forward op related to the backward op
ir::Node* forward_op =
FindForwardOpViaBackwardOp(graph.get(), backward_op);
VLOG(3) << "Found forward_op " << forward_op->Name();
PADDLE_ENFORCE(forward_op);
Node* new_optimizer_node = CreateNewSGDNode(
graph.get(), forward_op, backward_op, node, opt_node);
PADDLE_ENFORCE(new_optimizer_node);
}
}
}
}
}
// Remove the sum_op and its' outputs and connected Optimizers
for (Node* sum_op : grad_sum_op_set) {
for (Node* sum_op_output : sum_op->outputs) {
for (Node* optimize_op : sum_op_output->outputs) {
if (optimize_op->NodeType() == Node::Type::kOperation &&
optimize_op->Name() == kOptimizerType) {
VLOG(3) << "remove optimize_op: " << optimize_op->Name() << "_"
<< optimize_op->id();
graph->RemoveNode(optimize_op);
}
}
VLOG(3) << "remove sum_op_output: " << sum_op_output->Name() << "_"
<< sum_op_output->id();
graph->RemoveNode(sum_op_output);
}
VLOG(3) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id();
graph->RemoveNode(sum_op);
}
for (auto* node : graph->Nodes()) {
for (Node* output_node : node->outputs) {
if (output_node->Name() == "sgd") {
VLOG(3) << "Node link to SGD: " << node->Name() << "_" << node->id()
<< " --> " << output_node->Name() << "_" << output_node->id();
for (Node* input_node : node->inputs) {
VLOG(3) << "SGD Input link: " << input_node->Name() << "_"
<< input_node->id() << " --> " << node->Name() << "_"
<< node->id();
}
}
}
}
return graph;
}
ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
ir::Graph* graph, ir::Node* forward_node, ir::Node* backward_node,
ir::Node* grad_sum_node, ir::Node* optimize_node) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(forward_node);
PADDLE_ENFORCE(backward_node);
PADDLE_ENFORCE(grad_sum_node);
PADDLE_ENFORCE(optimize_node);
// find the grad var node between the grad sum node and backward_node
std::vector<ir::Node*> grad_vars =
FindConnectedNode(backward_node, grad_sum_node);
ir::Node* grad_node = nullptr;
for (ir::Node* node : grad_vars) {
if (!ir::IsControlDepVar(*node)) {
grad_node = node;
}
}
PADDLE_ENFORCE(grad_node);
// create a new SGD node
OpDesc* old_desc = optimize_node->Op();
// keep with the same block between new optimizer and the old one
OpDesc new_desc(*old_desc, old_desc->Block());
new_desc.SetInput("Param", old_desc->Input("Param"));
new_desc.SetInput("LearningRate", old_desc->Input("LearningRate"));
new_desc.SetInput("Grad", std::vector<std::string>({grad_node->Name()}));
new_desc.SetOutput("ParamOut", old_desc->Output("ParamOut"));
std::vector<std::string> op_role_vars = boost::get<std::vector<std::string>>(
new_desc.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName()));
// replace the second op role var, because the grad name was
// changed in new optimizer
op_role_vars.pop_back();
op_role_vars.push_back(grad_node->Name());
new_desc.SetAttr(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(),
op_role_vars);
new_desc.SetType(kOptimizerType);
// set backward op's op role var, this will be used to
// set device_id in multi_device_pass
backward_node->Op()->SetAttr(
framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), op_role_vars);
// backward_node->Op()->SetAttr(
// framework::OpProtoAndCheckerMaker::OpRoleVarAttrName(), {});
// keep with the same output nodes between new optimizer and the
// old one
Node* sgd_node = graph->CreateOpNode(&new_desc);
// change all outputs of the optimize_node to the new one
ReplaceAllDownstreamNode(optimize_node, sgd_node);
// find connected node between forward node and optimize node
// and replace the optimize node to new sgd node
std::vector<ir::Node*> forward_opt_connected_nodes =
FindConnectedNode(forward_node, optimize_node);
for (ir::Node* node : forward_opt_connected_nodes) {
ReplaceUpstreamNode(node, optimize_node, sgd_node);
}
// find connected node between backward node and optimize node
// and replace the optimize node to new sgd node
std::vector<ir::Node*> backward_opt_connected_nodes =
FindConnectedNode(backward_node, optimize_node);
for (ir::Node* node : backward_opt_connected_nodes) {
ReplaceUpstreamNode(node, optimize_node, sgd_node);
}
// SGD must have only one param and LR in
PADDLE_ENFORCE(old_desc->Input("LearningRate").size() == 1u);
PADDLE_ENFORCE(old_desc->Input("Param").size() == 1u);
// LR and weight nodes should be copied
for (Node* upstream_node : optimize_node->inputs) {
if (upstream_node->Name() == old_desc->Input("LearningRate")[0] ||
upstream_node->Name() == old_desc->Input("Param")[0]) {
ReplaceUpstreamNode(upstream_node, optimize_node, sgd_node);
}
}
VLOG(3) << "Create new opt node" << sgd_node->Name() << "_" << sgd_node->id();
return sgd_node;
}
std::vector<ir::Node*> LockFreeOptimizePass::FindConnectedNode(
ir::Node* upstream_node, ir::Node* downstream_node) const {
std::vector<ir::Node*> result;
for (ir::Node* out_node : upstream_node->outputs) {
for (ir::Node* in_node : downstream_node->inputs) {
if (in_node == out_node) {
result.push_back(in_node);
}
}
}
return result;
}
void LockFreeOptimizePass::ReplaceUpstreamNode(
ir::Node* upstream_node, ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(upstream_node);
PADDLE_ENFORCE(old_optimizer_node);
PADDLE_ENFORCE(new_optimizer_node);
// Remove the old_optimizer_node from upstream_node's outputs vector
auto& output_node_vec = upstream_node->outputs;
for (auto output_node_iter = output_node_vec.begin();
output_node_iter != output_node_vec.end();) {
if (*output_node_iter == old_optimizer_node) {
output_node_vec.erase(output_node_iter);
break;
} else {
++output_node_iter;
}
}
// Add the new_optimizer_node to upstream_node's outputs vector
output_node_vec.emplace_back(new_optimizer_node);
new_optimizer_node->inputs.emplace_back(upstream_node);
}
void LockFreeOptimizePass::ReplaceAllDownstreamNode(
ir::Node* old_optimizer_node, ir::Node* new_optimizer_node) const {
PADDLE_ENFORCE(old_optimizer_node);
PADDLE_ENFORCE(new_optimizer_node);
for (ir::Node* downstream_node : old_optimizer_node->outputs) {
// Remove the old_optimizer_node from downstream_node's inputs vector
auto& input_node_vec = downstream_node->inputs;
for (auto input_node_iter = input_node_vec.begin();
input_node_iter != input_node_vec.end();) {
if (*input_node_iter == old_optimizer_node) {
input_node_vec.erase(input_node_iter);
break;
} else {
++input_node_iter;
}
}
// Add the new_optimizer_node to downstream_node's inputs vector
input_node_vec.emplace_back(new_optimizer_node);
new_optimizer_node->outputs.emplace_back(downstream_node);
}
}
ir::Node* LockFreeOptimizePass::FindForwardOpViaBackwardOp(
ir::Graph* graph, ir::Node* backward_node) const {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(backward_node);
// strip the suffix _grad of backward_node's name
std::string forward_op_name = backward_node->Name();
const std::string suffix("_grad");
if (forward_op_name != suffix && forward_op_name.size() > suffix.size() &&
forward_op_name.substr(forward_op_name.size() - suffix.size()) ==
suffix) {
// if so then strip them off
forward_op_name =
forward_op_name.substr(0, forward_op_name.size() - suffix.size());
} else {
LOG(WARNING) << "Illegal backward node's name " << backward_node->Name()
<< " id " << backward_node->id();
return nullptr;
}
for (ir::Node* node : graph->Nodes()) {
if (node->Name() == forward_op_name) {
if (node->outputs.size() == 0u) {
// if forward_node has no output, then it has NO grad op
continue;
}
// check whether all inputs of the backward_op that ends_with @GRAD
// comes from the output of forward_op is the input of the backward_op
bool is_related_forward_node = true;
for (ir::Node* backward_input : backward_node->inputs) {
if (IsVarNameEndsWith(backward_input, kGradVarSuffix)) {
bool meets_correct_output = false;
for (ir::Node* forward_output : node->outputs) {
if (forward_output->Name() + kGradVarSuffix ==
backward_input->Name()) {
meets_correct_output = true;
break;
}
}
if (!meets_correct_output) {
is_related_forward_node = false;
break;
}
}
}
if (is_related_forward_node) {
return node;
}
}
}
return nullptr;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(lock_free_optimize_pass,
paddle::framework::ir::LockFreeOptimizePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
#define PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
#include <string>
#include <vector>
#include <boost/algorithm/string/predicate.hpp>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
/*
* Remove the sum op of all gradients of the backward op.
* And remove the dependecies of the optimizer related to the
* same backward op.
*
* Before this pass:
*
* forward_op1 forward_op2
* | |
* grad_op1 grad_op2
* \ /
* \ /
* sum_op
* |
* sgd_op
*
* After this pass:
* forward_op1 forward_op2
* | |
* grad_op1 grad_op2
* | |
* sgd_op1 sgd_op2
*
* sgd_op1 and sgd_op2 will update the same weight which holds the same
* memory, so we could benefits from the acceleration
*/
class LockFreeOptimizePass : public Pass {
public:
virtual ~LockFreeOptimizePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
private:
// Create a new sgd node via current optimizer node
ir::Node* CreateNewSGDNode(ir::Graph* graph, ir::Node* forward_node,
ir::Node* backward_node, ir::Node* grad_sum_node,
ir::Node* optimize_node) const;
// Replace the input weight's optimizers
void ReplaceUpstreamNode(ir::Node* upstream_node,
ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const;
// Replace the output weight's optimizers
void ReplaceAllDownstreamNode(ir::Node* old_optimizer_node,
ir::Node* new_optimizer_node) const;
// Find all weight variables in graph
bool FindAllWeightVars(ir::Graph* graph) const;
// Find the forward_op node via the backward_op node
ir::Node* FindForwardOpViaBackwardOp(ir::Graph* graph,
ir::Node* backward_node) const;
std::vector<ir::Node*> FindConnectedNode(ir::Node* upstream_node,
ir::Node* downstream_node) const;
inline bool IsOpNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kOperation && node->Name() == name;
}
inline bool IsVarNamed(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kVariable && node->Name() == name;
}
inline bool IsVarNameEndsWith(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kVariable &&
boost::algorithm::ends_with(node->Name(), name);
}
inline bool IsVarNameContains(ir::Node* node, const std::string& name) const {
PADDLE_ENFORCE(node);
return node->NodeType() == Node::Type::kVariable &&
node->Name().find(name) != std::string::npos;
}
inline bool IsControlDepFrom(ir::Node* ctrl_dep_node, ir::Node* node) const {
PADDLE_ENFORCE(ctrl_dep_node);
PADDLE_ENFORCE(node);
return IsControlDepVar(*ctrl_dep_node) &&
ctrl_dep_node->inputs.size() >= 1u &&
ctrl_dep_node->inputs[0] == node;
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_IR_LOCK_FREE_OPTIMIZE_PASS_H_
...@@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) { ...@@ -87,11 +87,12 @@ Variable* Scope::Var(const std::string& name) {
} }
Variable* Scope::Var(std::string* name) { Variable* Scope::Var(std::string* name) {
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); SCOPE_VARS_WRITER_LOCK
auto new_name = std::to_string(reinterpret_cast<uintptr_t>(this)) + "." +
std::to_string(vars_.size());
if (name != nullptr) { if (name != nullptr) {
*name = new_name; *name = new_name;
} }
SCOPE_VARS_WRITER_LOCK
return VarInternal(new_name); return VarInternal(new_name);
} }
......
...@@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder { ...@@ -105,13 +105,15 @@ struct VarIdToTypeIndexMapHolder {
} // namespace detail } // namespace detail
const std::type_index &ToTypeIndex(int var_id) { const std::type_index &VarTraitIdToTypeIndex(int var_id) {
return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id); return detail::VarIdToTypeIndexMapHolder::ToTypeIndex(var_id);
} }
const char *ToTypeName(int var_id) { return ToTypeIndex(var_id).name(); } const char *ToTypeName(int var_id) {
return VarTraitIdToTypeIndex(var_id).name();
}
int ToTypeId(const std::type_index &type) { int TypeIndexToVarTraitId(const std::type_index &type) {
return detail::VarIdToTypeIndexMapHolder::ToTypeId(type); return detail::VarIdToTypeIndexMapHolder::ToTypeId(type);
} }
......
...@@ -66,8 +66,8 @@ namespace paddle { ...@@ -66,8 +66,8 @@ namespace paddle {
namespace framework { namespace framework {
const char *ToTypeName(int var_id); const char *ToTypeName(int var_id);
const std::type_index &ToTypeIndex(int var_id); const std::type_index &VarTraitIdToTypeIndex(int var_id);
int ToTypeId(const std::type_index &type); int TypeIndexToVarTraitId(const std::type_index &type);
namespace detail { namespace detail {
......
...@@ -45,10 +45,11 @@ struct TypeIndexChecker { ...@@ -45,10 +45,11 @@ struct TypeIndexChecker {
constexpr auto kId = VarTypeTrait<Type>::kId; constexpr auto kId = VarTypeTrait<Type>::kId;
std::type_index actual_type(typeid(Type)); std::type_index actual_type(typeid(Type));
EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name())); EXPECT_EQ(std::string(ToTypeName(kId)), std::string(actual_type.name()));
EXPECT_EQ(ToTypeIndex(kId), actual_type); EXPECT_EQ(VarTraitIdToTypeIndex(kId), actual_type);
EXPECT_EQ(ToTypeId(actual_type), kId); EXPECT_EQ(TypeIndexToVarTraitId(actual_type), kId);
EXPECT_EQ(ToTypeIndex(ToTypeId(actual_type)), actual_type); EXPECT_EQ(VarTraitIdToTypeIndex(TypeIndexToVarTraitId(actual_type)),
EXPECT_EQ(ToTypeId(ToTypeIndex(kId)), kId); actual_type);
EXPECT_EQ(TypeIndexToVarTraitId(VarTraitIdToTypeIndex(kId)), kId);
EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT EXPECT_TRUE(var_id_set->count(kId) == 0); // NOLINT
EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT EXPECT_TRUE(type_index_set->count(actual_type) == 0); // NOLINT
......
...@@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) { ...@@ -80,8 +80,8 @@ void TestWord2vecPrediction(const std::string& model_path) {
i++) { i++) {
LOG(INFO) << "data: " << static_cast<float*>(outputs.front().data.data())[i] LOG(INFO) << "data: " << static_cast<float*>(outputs.front().data.data())[i]
<< " result: " << result[i]; << " result: " << result[i];
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i], EXPECT_NEAR(static_cast<float*>(outputs.front().data.data())[i], result[i],
result[i]); 1e-3);
} }
} }
......
...@@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps} ...@@ -7,4 +7,5 @@ set(analysis_deps ${analysis_deps}
ir_graph_build_pass ir_graph_build_pass
ir_analysis_pass ir_analysis_pass
analysis_passes analysis_passes
subgraph_detector
CACHE INTERNAL "") CACHE INTERNAL "")
...@@ -190,6 +190,26 @@ void BenchGRUKernel() { ...@@ -190,6 +190,26 @@ void BenchGRUKernel() {
} }
} }
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
attr.h = h;
std::vector<T> x(h * w), y(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* y_data = y.data();
BenchAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType>(attr, x_data,
y_data, &attr);
}
}
}
}
// Benchmark all jit kernels including jitcode, mkl and refer. // Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...] // To use this tool, run command: ./benchmark [options...]
// Options: // Options:
...@@ -228,4 +248,7 @@ int main(int argc, char* argv[]) { ...@@ -228,4 +248,7 @@ int main(int argc, char* argv[]) {
BenchGRUKernel<jit::kGRUH1, T, PlaceType>(); BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>(); BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>(); BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
// seq pool function
BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
} }
...@@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1) ...@@ -26,3 +26,4 @@ USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN(kGRUHtPart1) USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2) USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC) USE_JITKERNEL_GEN(kNCHW16CMulNC)
USE_JITKERNEL_GEN(kSeqPool)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/seqpool.h"
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void SeqPoolJitCode::genCode() {
constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 8;
const int num_block = w_ / block;
const int num_groups = num_block / max_num_regs;
int rest_num_regs = num_block % max_num_regs;
mov(reg32_int_h, dword[param_attr]);
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovups(xmm_t(1), ptr[reg_tmp + OFFSET_EXP_ONE]);
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
fild(dword[param_attr]);
fstp(dword[reg_tmp]);
vmovss(xmm_t(0), ptr[reg_tmp]);
if (type_ == SeqPoolType::kSqrt) {
vsqrtps(xmm_t(0), xmm_t(0));
}
vdivps(xmm_t(1), xmm_t(1), xmm_t(0));
vmovss(ptr[reg_tmp], xmm_t(1));
}
const int group_len = max_num_regs * block * sizeof(float);
for (int g = 0; g < num_groups; ++g) {
pool_height<ymm_t>(g * group_len, block, max_num_regs);
}
if (rest_num_regs > 0) {
pool_height<ymm_t>(num_groups * group_len, block, rest_num_regs);
}
// part of rest_w * height
const int rest = w_ % block;
pool_height_of_rest_width(rest, (w_ - rest) * sizeof(float), max_num_regs);
ret();
}
class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
public:
bool UseMe(const seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx);
}
size_t CodeSize(const seq_pool_attr_t& attr) const override {
return 96 +
((attr.w / YMM_FLOAT_BLOCK + 4 /* for rest */) *
4 /* load, mul and save */ +
256) *
8;
}
std::unique_ptr<GenBase> CreateJitCode(
const seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.w, 0);
PADDLE_ENFORCE_GT(attr.h, 0);
return make_unique<SeqPoolJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kSeqPool, gen::SeqPoolCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class SeqPoolJitCode : public JitCode {
public:
explicit SeqPoolJitCode(const seq_pool_attr_t& attr,
size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
type_ == SeqPoolType::kSqrt)) {
LOG(FATAL) << "Only support sum pool yet ";
}
fp_h_[0] = 1.f;
this->genCode();
}
virtual const char* name() const {
std::string base = "SeqPoolJitCode";
if (type_ == SeqPoolType::kSum) {
base += "_Sum";
} else if (type_ == SeqPoolType::kAvg) {
base += "_Avg";
} else if (type_ == SeqPoolType::kSqrt) {
base += "_Sqrt";
}
base += ("_W" + std::to_string(w_));
return base.c_str();
}
void genCode() override;
protected:
template <typename JMM>
void pool_height(int w_offset, int block, int max_num_regs) {
int offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i), ptr[param_src + offset]);
offset += sizeof(float) * block;
}
cmp(reg32_int_h, 1);
Label l_next_h, l_h_done;
jle(l_h_done, T_NEAR);
mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
mov(reg_ptr_src_i, reg_tmp);
for (int i = 0; i < max_num_regs; ++i) {
vmovups(JMM(i + max_num_regs), ptr[reg_ptr_src_i]);
// sum anyway
vaddps(JMM(i), JMM(i), JMM(i + max_num_regs));
add(reg_ptr_src_i, sizeof(float) * block);
}
inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR);
}
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vbroadcastss(JMM(max_num_regs), ptr[reg_tmp]);
}
offset = w_offset;
for (int i = 0; i < max_num_regs; ++i) {
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
vmulps(JMM(i), JMM(i), JMM(max_num_regs));
}
vmovups(ptr[param_dst + offset], JMM(i));
offset += sizeof(float) * block;
}
}
void pool_height_of_rest_width(int rest, int w_offset, int max_num_regs) {
const int rest_used_num_regs = load_rest(rest, w_offset, 0);
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
cmp(reg32_int_h, 1);
Label l_next_h, l_h_done;
jle(l_h_done, T_NEAR);
mov(reg_h_i, 1);
mov(reg_tmp, param_src);
add(reg_tmp, w_ * sizeof(float) + w_offset);
L(l_next_h);
{
int reg_idx = 0;
mov(reg_ptr_src_i, reg_tmp);
if (has_block4) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 4);
reg_idx++;
}
if (has_block2) {
vmovups(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
add(reg_ptr_src_i, sizeof(float) * 2);
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++;
}
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
"All heights should use same regs");
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
inc(reg_h_i);
add(reg_tmp, w_ * sizeof(float));
cmp(reg_h_i, reg32_int_h);
jl(l_next_h, T_NEAR);
}
L(l_h_done);
// save right now
if (type_ == SeqPoolType::kAvg || type_ == SeqPoolType::kSqrt) {
mov(reg_tmp, reinterpret_cast<size_t>(fp_h_));
vbroadcastss(xmm_t(max_num_regs), ptr[reg_tmp]);
for (int i = 0; i < rest_used_num_regs; ++i) {
vmulps(xmm_t(i), xmm_t(i), xmm_t(max_num_regs));
}
}
save_rest(rest, w_offset);
}
// return the number of used regs, use start from reg 0
int load_rest(int rest, int w_offset, const int num_shift_regs,
const int reg_start = 0) {
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(xmm_t(reg_idx + num_shift_regs), ptr[param_src + w_offset]);
reg_idx++;
}
return reg_idx;
}
// use reg start from 0
void save_rest(int rest, int w_offset, int reg_start = 0) {
const bool has_block4 = rest / 4 > 0;
const bool has_block2 = (rest % 4) / 2 > 0;
const bool has_block1 = (rest % 2) == 1;
int reg_idx = reg_start;
if (has_block4) {
vmovups(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 4;
reg_idx++;
}
if (has_block2) {
vmovq(ptr[param_dst + w_offset], xmm_t(reg_idx));
w_offset += sizeof(float) * 2;
reg_idx++;
}
if (has_block1) {
vmovss(ptr[param_dst + w_offset], xmm_t(reg_idx));
}
}
private:
float ALIGN32_BEG fp_h_[1] ALIGN32_END;
int w_;
SeqPoolType type_;
reg64_t param_src{abi_param1};
reg64_t param_dst{abi_param2};
reg64_t param_attr{abi_param3};
reg64_t reg_tmp{rax};
reg32_t reg32_int_h{r8d};
reg32_t reg32_fp_h{r9d};
reg64_t reg_h_i{r10};
reg64_t reg_ptr_src_i{r11};
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
...@@ -26,6 +26,7 @@ namespace jit { ...@@ -26,6 +26,7 @@ namespace jit {
const char* to_string(KernelType kt) { const char* to_string(KernelType kt) {
switch (kt) { switch (kt) {
ONE_CASE(kNone);
ONE_CASE(kVMul); ONE_CASE(kVMul);
ONE_CASE(kVAdd); ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu); ONE_CASE(kVAddRelu);
...@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) { ...@@ -45,12 +46,26 @@ const char* to_string(KernelType kt) {
ONE_CASE(kCRFDecoding); ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm); ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC); ONE_CASE(kNCHW16CMulNC);
ONE_CASE(kSeqPool);
default: default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt); PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel"; return "NOT JITKernel";
} }
return nullptr; return nullptr;
} }
const char* to_string(SeqPoolType tp) {
switch (tp) {
ONE_CASE(kNonePoolType);
ONE_CASE(kSum);
ONE_CASE(kAvg);
ONE_CASE(kSqrt);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", tp);
return "NOT PoolType";
}
return nullptr;
}
#undef ONE_CASE #undef ONE_CASE
KernelType to_kerneltype(const std::string& act) { KernelType to_kerneltype(const std::string& act) {
......
...@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get( ...@@ -119,6 +119,7 @@ typename KernelTuples::func_type Get(
} }
const char* to_string(KernelType kt); const char* to_string(KernelType kt);
const char* to_string(SeqPoolType kt);
KernelType to_kerneltype(const std::string& act); KernelType to_kerneltype(const std::string& act);
...@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) { ...@@ -134,6 +135,11 @@ inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
<< "],act_cand[" << to_string(attr.act_cand) << "]"; << "],act_cand[" << to_string(attr.act_cand) << "]";
return os; return os;
} }
inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) {
os << "height_size[" << attr.h << "],width_size[" << attr.w << "],pool_type["
<< to_string(attr.type) << "]";
return os;
}
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
......
...@@ -41,8 +41,16 @@ typedef enum { ...@@ -41,8 +41,16 @@ typedef enum {
kCRFDecoding, kCRFDecoding,
kLayerNorm, kLayerNorm,
kNCHW16CMulNC, kNCHW16CMulNC,
kSeqPool,
} KernelType; } KernelType;
typedef enum {
kNonePoolType = 0,
kSum = 1,
kAvg,
kSqrt,
} SeqPoolType;
template <typename T> template <typename T>
struct XYZNTuples { struct XYZNTuples {
typedef T data_type; typedef T data_type;
...@@ -112,6 +120,21 @@ struct GRUTuples { ...@@ -112,6 +120,21 @@ struct GRUTuples {
typedef void (*func_type)(gru_t*, const gru_attr_t*); typedef void (*func_type)(gru_t*, const gru_attr_t*);
}; };
typedef struct seq_pool_attr_s {
int h, w; // h should always be the first one
SeqPoolType type;
seq_pool_attr_s() = default;
explicit seq_pool_attr_s(int width, SeqPoolType pool_type, int height = 1)
: h(height), w(width), type(pool_type) {}
} seq_pool_attr_t;
template <typename T>
struct SeqPoolTuples {
typedef T data_type;
typedef seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};
template <typename T> template <typename T>
struct CRFDecodingTuples { struct CRFDecodingTuples {
typedef T data_type; typedef T data_type;
......
...@@ -42,6 +42,13 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) { ...@@ -42,6 +42,13 @@ size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
(static_cast<int>(attr.act_cand) << act_type_shift); (static_cast<int>(attr.act_cand) << act_type_shift);
} }
template <>
size_t JitCodeKey<seq_pool_attr_t>(const seq_pool_attr_t& attr) {
size_t key = attr.w;
constexpr int pool_type_shift = 3;
return (key << pool_type_shift) + static_cast<int>(attr.type);
}
} // namespace jit } // namespace jit
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl) ...@@ -9,3 +9,4 @@ USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl) USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
...@@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) { ...@@ -72,6 +72,26 @@ void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y); platform::dynload::vdExp(n, x, y);
} }
template <>
void VCopy<float>(const float* x, float* y, int n) {
platform::dynload::cblas_scopy(n, x, 1, y, 1);
}
template <>
void VCopy<double>(const double* x, double* y, int n) {
platform::dynload::cblas_dcopy(n, x, 1, y, 1);
}
template <>
void VAXPY<float>(float a, const float* x, float* y, int n) {
platform::dynload::cblas_saxpy(n, a, x, 1, y, 1);
}
template <>
void VAXPY<double>(double a, const double* x, double* y, int n) {
platform::dynload::cblas_daxpy(n, a, x, 1, y, 1);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <> template <>
bool VMulKernel<float>::UseMe(const int& d) const { bool VMulKernel<float>::UseMe(const int& d) const {
...@@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const { ...@@ -103,6 +123,16 @@ bool VTanhKernel<float>::UseMe(const int& d) const {
return d > 7; return d > 7;
} }
template <>
bool SeqPoolKernel<float>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \ #define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \ template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \ bool func##Kernel<double>::UseMe(const int& d) const { \
...@@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal); ...@@ -135,5 +165,6 @@ REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
#undef REGISTER_MKL_KERNEL #undef REGISTER_MKL_KERNEL
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <cmath>
#include <type_traits> #include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
...@@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n); ...@@ -35,6 +36,12 @@ void VScal(const T* a, const T* x, T* y, int n);
template <typename T> template <typename T>
void VExp(const T* x, T* y, int n); void VExp(const T* x, T* y, int n);
template <typename T>
void VCopy(const T* x, T* y, int n);
template <typename T>
void VAXPY(T a, const T* x, T* y, int n);
template <typename T> template <typename T>
void VSigmoid(const T* x, T* y, int n) { void VSigmoid(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN; const T min = SIGMOID_THRESHOLD_MIN;
...@@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) { ...@@ -60,6 +67,23 @@ void VTanh(const T* x, T* y, int n) {
} }
} }
template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
VCopy<T>(x, y, attr->w);
for (int h = 1; h != attr->h; ++h) {
VAXPY<T>(static_cast<T>(1), x + h * attr->w, y, attr->w);
}
if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
T scalar = static_cast<T>(1);
if (attr->type == SeqPoolType::kAvg) {
scalar = scalar / static_cast<T>(attr->h);
} else {
scalar = scalar / std::sqrt(static_cast<T>(attr->h));
}
VScal<T>(&scalar, y, y, attr->w);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \ #define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \ template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \ class name##Kernel : public KernelMore<tuples<T>> { \
...@@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples); ...@@ -81,6 +105,8 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
#undef DECLARE_MKL_KERNEL #undef DECLARE_MKL_KERNEL
} // namespace mkl } // namespace mkl
......
...@@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2) ...@@ -26,3 +26,4 @@ USE_JITKERNEL_REFER(kGRUHtPart2)
USE_JITKERNEL_REFER(kCRFDecoding) USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER(kLayerNorm) USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC) USE_JITKERNEL_REFER(kNCHW16CMulNC)
USE_JITKERNEL_REFER(kSeqPool)
...@@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm); ...@@ -47,4 +47,6 @@ REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
#undef REGISTER_REFER_KERNEL #undef REGISTER_REFER_KERNEL
...@@ -332,6 +332,28 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) { ...@@ -332,6 +332,28 @@ void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
} }
} }
template <typename T>
void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
for (int w = 0; w < attr->w; ++w) {
const T* src = x + w;
T* dst = y + w;
*dst = static_cast<T>(0);
for (int h = 0; h < attr->h; ++h) {
*dst = *dst + *src;
src += attr->w;
}
}
if (attr->type == SeqPoolType::kAvg || attr->type == SeqPoolType::kSqrt) {
T scalar = static_cast<T>(1);
if (attr->type == SeqPoolType::kAvg) {
scalar = scalar / static_cast<T>(attr->h);
} else {
scalar = scalar / std::sqrt(static_cast<T>(attr->h));
}
VScal<T>(&scalar, y, y, attr->w);
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \ #define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \ template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \ class name##Kernel : public ReferKernel<tuples<T>> { \
...@@ -370,6 +392,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples); ...@@ -370,6 +392,8 @@ DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
#undef DECLARE_REFER_KERNEL #undef DECLARE_REFER_KERNEL
} // namespace refer } // namespace refer
......
...@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>, ...@@ -211,6 +211,24 @@ struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
} }
}; };
template <typename T>
struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::SeqPoolTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref,
const typename jit::SeqPoolTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(x.size() % yref.size(), 0);
int w = yref.size();
std::vector<T> y(w);
const T* x_data = x.data();
const T* yref_data = yref.data();
T* y_data = y.data();
tgt(x_data, y_data, &attr);
ExpectEQ<T>(y_data, yref_data, w);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples, template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args> typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
...@@ -415,6 +433,31 @@ void TestGRUKernel() { ...@@ -415,6 +433,31 @@ void TestGRUKernel() {
} }
} }
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
for (auto type : pool_types) {
for (int w : TestSizes()) {
jit::seq_pool_attr_t attr(w, type);
for (int h : TestSizes()) {
attr.h = h;
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f);
const T* x_data = x.data();
T* yref_data = yref.data();
ref(x_data, yref_data, &attr);
VLOG(10) << attr;
TestAllImpls<KT, jit::SeqPoolTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(attr, x, yref, attr);
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType> template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() { void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
...@@ -569,6 +612,12 @@ TEST(JITKernel, kGRUHtPart2) { ...@@ -569,6 +612,12 @@ TEST(JITKernel, kGRUHtPart2) {
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>(); TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
} }
TEST(JITKernel, kSeqPool) {
namespace jit = paddle::operators::jit;
TestSeqPoolKernel<jit::kSeqPool, float, paddle::platform::CPUPlace>();
TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) { TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit; namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
......
...@@ -51,7 +51,7 @@ math_library(pooling) ...@@ -51,7 +51,7 @@ math_library(pooling)
math_library(selected_rows_functor DEPS selected_rows math_function blas) math_library(selected_rows_functor DEPS selected_rows math_function blas)
math_library(sequence2batch) math_library(sequence2batch)
math_library(sequence_padding) math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function) math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale) math_library(sequence_scale)
math_library(softmax DEPS math_function) math_library(softmax DEPS math_function)
......
...@@ -62,27 +62,19 @@ struct CUBlas<float> { ...@@ -62,27 +62,19 @@ struct CUBlas<float> {
cudaDataType_t Atype, int lda, const void *B, cudaDataType_t Atype, int lda, const void *B,
cudaDataType_t Btype, int ldb, const float *beta, void *C, cudaDataType_t Btype, int ldb, const float *beta, void *C,
cudaDataType_t Ctype, int ldc) { cudaDataType_t Ctype, int ldc) {
// Because the gcc 4.8 doesn't expand template parameter pack that // Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack // appears in a lambda-expression, I can not use template parameter pack
// here. // here.
auto cublas_call = [&]() {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (platform::TensorCoreAvailable() ? "True" : "False"); << (dev_ctx->tensor_core_available() ? "True" : "False");
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasSgemmEx( PADDLE_ENFORCE(platform::dynload::cublasSgemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
lda, B, Btype, ldb, beta, C, Ctype, ldc)); beta, C, Ctype, ldc));
});
#else #else
PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
#else
cublas_call();
#endif #endif
} }
}; };
...@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> { ...@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> {
cudaDataType_t Btype, int ldb, const void *beta, void *C, cudaDataType_t Btype, int ldb, const void *beta, void *C,
cudaDataType_t Ctype, int ldc, cudaDataType_t Ctype, int ldc,
cudaDataType_t computeType) { cudaDataType_t computeType) {
auto cublas_call = [&]() {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000 #if CUDA_VERSION >= 9000
bool use_tensor_op_math = platform::TensorCoreAvailable(); bool use_tensor_op_math = dev_ctx->tensor_core_available();
if (use_tensor_op_math) { if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP; algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
} }
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False"); << (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmEx( PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
dev_ctx->cublas_handle(), transa, transb, m, n, k, alpha, A, Atype, handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)); beta, C, Ctype, ldc, computeType, algo));
});
#else #else
PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0"); PADDLE_THROW("cublasGemmEx is supported on cuda >= 8.0");
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx->CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
#else
cublas_call();
#endif #endif
} }
}; };
...@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F, N); CUDA_R_32F, N);
} else { } else {
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
&alpha, B, ldb, A, lda, &beta, C, N); lda, &beta, C, N);
});
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
} }
...@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM( ...@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F); CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, CUDA_R_32F);
#else #else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
N, M, K, &h_alpha, h_B, ldb, h_A, lda, context_.CublasCall([&](cublasHandle_t handle) {
&h_beta, h_C, N); CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C,
N);
});
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
} }
...@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M, ...@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
} else { } else {
#endif // CUDA_VERSION >= 8000 #endif // CUDA_VERSION >= 8000
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, context_.CublasCall([&](cublasHandle_t handle) {
&alpha, B, ldb, A, lda, &beta, C, ldc); CUBlas<T>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A,
lda, &beta, C, ldc);
});
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
} }
...@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM( ...@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, context_.CublasCall([&](cublasHandle_t handle) {
N, M, K, &alpha, B, ldb, A, lda, &beta, C, CUBlas<platform::float16>::GEMM(handle, cuTransB, cuTransA, N, M, K, &alpha,
ldc); B, ldb, A, lda, &beta, C, ldc);
});
} }
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x, void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
T *y) const { T *y) const {
CUBlas<T>::AXPY(context_.cublas_handle(), n, &alpha, x, 1, y, 1); context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::AXPY(handle, n, &alpha, x, 1, y, 1);
});
} }
template <> template <>
...@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N, ...@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T beta, T *C) const { T beta, T *C) const {
cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cuTransA = !trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
CUBlas<T>::GEMV(context_.cublas_handle(), cuTransA, N, M, &alpha, A, N, B, 1, context_.CublasCall([&](cublasHandle_t handle) {
&beta, C, 1); CUBlas<T>::GEMV(handle, cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1);
});
} }
template <> template <>
...@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( ...@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010 #if CUDA_VERSION >= 9010
if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) { if (FLAGS_enable_cublas_tensor_op_math && std::is_same<T, float>::value) {
auto cublas_call = [&]() { cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; bool use_tensor_op_math = context_.tensor_core_available();
bool use_tensor_op_math = platform::TensorCoreAvailable(); if (use_tensor_op_math) {
if (use_tensor_op_math) { algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
algo = CUBLAS_GEMM_DFALT_TENSOR_OP; }
} VLOG(5) << "use_tensor_op_math: "
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
<< (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx( PADDLE_ENFORCE(platform::dynload::cublasGemmStridedBatchedEx(
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, handle, cuTransB, cuTransA, N, M, K, &alpha, B, CUDA_R_32F, ldb,
CUDA_R_32F, ldb, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, strideB, A, CUDA_R_32F, lda, strideA, &beta, C, CUDA_R_32F, ldc,
CUDA_R_32F, ldc, strideC, batchCount, CUDA_R_32F, algo)); strideC, batchCount, CUDA_R_32F, algo));
}; });
auto &dev_ctx = const_cast<platform::CUDADeviceContext &>(context_);
dev_ctx.CublasCall(cublas_call, CUBLAS_TENSOR_OP_MATH);
} else { } else {
#endif // CUDA_VERSION >= 9010 #endif // CUDA_VERSION >= 9010
CUBlas<T>::GEMM_STRIDED_BATCH(context_.cublas_handle(), cuTransB, cuTransA, context_.CublasCall([&](cublasHandle_t handle) {
N, M, K, &alpha, B, ldb, strideB, A, lda, CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
strideA, &beta, C, ldc, strideC, batchCount); B, ldb, strideB, A, lda, strideA, &beta, C,
ldc, strideC, batchCount);
});
#if CUDA_VERSION >= 9010 #if CUDA_VERSION >= 9010
} }
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/sequence_pooling.h" #include "paddle/fluid/operators/math/sequence_pooling.h"
...@@ -239,15 +240,33 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -239,15 +240,33 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
last_pool(context, input, output); last_pool(context, input, output);
return; return;
} }
if (pooltype == "FIRST") { if (pooltype == "FIRST") {
math::FirstSeqPoolFunctor<T> first_pool; math::FirstSeqPoolFunctor<T> first_pool;
first_pool(context, input, output); first_pool(context, input, output);
return; return;
} }
auto lod = input.lod()[0]; auto lod = input.lod()[0];
if (pooltype == "SUM") {
auto place = context.GetPlace();
PADDLE_ENFORCE(platform::is_cpu_place(place));
const T* src = input.data<T>();
T* dst = output->mutable_data<T>(place);
jit::seq_pool_attr_t attr(
static_cast<int>(input.numel() / input.dims()[0]),
jit::SeqPoolType::kSum);
auto seqpool =
jit::Get<jit::kSeqPool, jit::SeqPoolTuples<T>, platform::CPUPlace>(
attr);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
attr.h = static_cast<int>(lod[i + 1] - lod[i]);
seqpool(src, dst, &attr);
dst += attr.w;
src += attr.h * attr.w;
}
return;
}
auto& place = *context.eigen_device(); auto& place = *context.eigen_device();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
Tensor in_t = Tensor in_t =
input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1])); input.Slice(static_cast<int>(lod[i]), static_cast<int>(lod[i + 1]));
...@@ -258,15 +277,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> { ...@@ -258,15 +277,6 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
auto out_e = EigenVector<T>::Flatten(out_t); auto out_e = EigenVector<T>::Flatten(out_t);
if (pooltype == "AVERAGE") { if (pooltype == "AVERAGE") {
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}})); out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
} else if (pooltype == "SUM") {
if (h > 0) {
const T* in_data = in_t.data<T>();
T* out_data = out_t.mutable_data<T>(context.GetPlace());
blas.VCOPY(w, in_data, out_data);
for (int64_t r = 1; r != h; ++r) {
blas.AXPY(w, 1., in_data + r * w, out_data);
}
}
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) / out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
std::sqrt(static_cast<T>(h)); std::sqrt(static_cast<T>(h));
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <string> #include <string>
...@@ -48,4 +47,3 @@ static void BuildUnaryNode( ...@@ -48,4 +47,3 @@ static void BuildUnaryNode(
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <string> #include <string>
...@@ -58,4 +57,3 @@ std::shared_ptr<ngraph::Node> ElementwiseScalar( ...@@ -58,4 +57,3 @@ std::shared_ptr<ngraph::Node> ElementwiseScalar(
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <string> #include <string>
...@@ -58,4 +57,3 @@ void BuildFillConstantNode( ...@@ -58,4 +57,3 @@ void BuildFillConstantNode(
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <functional> #include <functional>
...@@ -65,4 +64,3 @@ void BuildMeanGradNode( ...@@ -65,4 +64,3 @@ void BuildMeanGradNode(
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <string> #include <string>
...@@ -131,4 +130,3 @@ static void BuildMulGradNode( ...@@ -131,4 +130,3 @@ static void BuildMulGradNode(
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <string> #include <string>
...@@ -38,4 +37,3 @@ void BuildScaleNode( ...@@ -38,4 +37,3 @@ void BuildScaleNode(
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <string> #include <string>
...@@ -48,4 +47,3 @@ void BuildTopKNode( ...@@ -48,4 +47,3 @@ void BuildTopKNode(
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
#endif
namespace paddle {
namespace platform {
class CublasHandleHolder {
public:
CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
PADDLE_ENFORCE(dynload::cublasCreate(&handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(handle_, stream));
#if CUDA_VERSION >= 9000
if (math_type == CUBLAS_TENSOR_OP_MATH) {
PADDLE_ENFORCE(
dynload::cublasSetMathMode(handle_, CUBLAS_TENSOR_OP_MATH));
}
#endif
}
~CublasHandleHolder() { PADDLE_ENFORCE(dynload::cublasDestroy(handle_)); }
template <typename Callback>
inline void Call(Callback &&callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
private:
DISABLE_COPY_AND_ASSIGN(CublasHandleHolder);
cublasHandle_t handle_;
mutable std::mutex mtx_;
};
} // namespace platform
} // namespace paddle
...@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) ...@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place); eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); cublas_handle_.reset(new CublasHandleHolder(stream_, CUBLAS_DEFAULT_MATH));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
if (TensorCoreAvailable()) {
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_.reset(
new CublasHandleHolder(stream_, CUBLAS_TENSOR_OP_MATH));
#endif
}
if (dynload::HasCUDNN()) { if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place)); cudnn_holder_.reset(new CudnnHolder(&stream_, place));
} }
...@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId(place_.device); SetDeviceId(place_.device);
Wait(); Wait();
WaitStreamCallback(); WaitStreamCallback();
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_)); cublas_handle_.reset();
cublas_tensor_core_handle_.reset();
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_)); PADDLE_ENFORCE(cudaStreamDestroy(stream_));
...@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { ...@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return eigen_device_.get(); return eigen_device_.get();
} }
cublasHandle_t CUDADeviceContext::cublas_handle() const { bool CUDADeviceContext::tensor_core_available() const {
return cublas_handle_; return cublas_tensor_core_handle_ != nullptr;
} }
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h" #include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
...@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle { ...@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
std::unique_ptr<std::lock_guard<std::mutex>> guard_; std::unique_ptr<std::lock_guard<std::mutex>> guard_;
}; };
#if CUDA_VERSION >= 9000
class ScopedCublasMathMode {
public:
ScopedCublasMathMode(cublasHandle_t handle, cublasMath_t new_math_mode)
: handle_(handle) {
need_reset = false;
PADDLE_ENFORCE(
platform::dynload::cublasGetMathMode(handle_, &old_math_mode_),
"Failed to get old cublas math mode");
if (old_math_mode_ != new_math_mode) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, new_math_mode),
"Failed to set old cublas math mode");
need_reset = true;
}
}
~ScopedCublasMathMode() {
if (need_reset) {
PADDLE_ENFORCE(
platform::dynload::cublasSetMathMode(handle_, old_math_mode_),
"Failed to set old cublas math mode");
}
}
private:
cublasHandle_t handle_;
cublasMath_t old_math_mode_;
bool need_reset;
};
#endif
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
explicit CUDADeviceContext(CUDAPlace place); explicit CUDADeviceContext(CUDAPlace place);
...@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext { ...@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */ /*! \brief Return eigen device in the device context. */
Eigen::GpuDevice* eigen_device() const; Eigen::GpuDevice* eigen_device() const;
/*! \brief Return cublas handle in the device context. */ /*! \brief Call cublas function safely. */
cublasHandle_t cublas_handle() const; template <typename Callback>
inline void CublasCall(Callback&& callback) const {
cublas_handle_->Call(std::forward<Callback>(callback));
}
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template <typename Callback>
inline void TensorCoreCublasCallIfAvailable(Callback&& callback) const {
if (cublas_tensor_core_handle_) {
cublas_tensor_core_handle_->Call(std::forward<Callback>(callback));
} else {
cublas_handle_->Call(std::forward<Callback>(callback));
}
}
/*! \brief Return cudnn handle in the device context. */ /*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const; cudnnHandle_t cudnn_handle() const;
...@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback> template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) { void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::mutex> guard(mtx_);
callback(); callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
...@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
void WaitStreamCallback() const { callback_manager_->Wait(); } void WaitStreamCallback() const { callback_manager_->Wait(); }
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template <typename Callback>
void CublasCall(Callback callback, cublasMath_t new_math) {
std::lock_guard<std::mutex> guard(cublas_mtx_);
ScopedCublasMathMode scoped_cublas_math(cublas_handle_, new_math);
callback();
}
#endif
private: private:
CUDAPlace place_; CUDAPlace place_;
...@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext { ...@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
std::unique_ptr<CudnnHolder> cudnn_holder_; std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_; cudaStream_t stream_;
cublasHandle_t cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
int compute_capability_; int compute_capability_;
int runtime_version_; int runtime_version_;
...@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext { ...@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int multi_process_; int multi_process_;
int max_threads_per_mp_; int max_threads_per_mp_;
mutable std::mutex mtx_;
// StreamCallbackManager is thread-safe // StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_; std::unique_ptr<StreamCallbackManager> callback_manager_;
mutable std::mutex cublas_mtx_; DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
}; };
template <> template <>
......
...@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) { ...@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, gpu_device); ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle); ASSERT_NE(nullptr, cudnn_handle);
cublasHandle_t cublas_handle = device_context->cublas_handle();
ASSERT_NE(nullptr, cublas_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context; delete device_context;
} }
} }
......
...@@ -155,7 +155,7 @@ def __bootstrap__(): ...@@ -155,7 +155,7 @@ def __bootstrap__():
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic', 'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit', 'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus', 'cudnn_exhaustive_search', 'memory_optimize_debug', 'selected_gpus',
'cudnn_exhaustive_search_times', 'sync_nccl_allreduce' 'sync_nccl_allreduce'
] ]
core.init_gflags([sys.argv[0]] + core.init_gflags([sys.argv[0]] +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册