Unverified commit c5548178, authored by Yiqun Liu, committed by GitHub

Add a pass to enable the use of cuDNN (#19346)

* Add an interface to enable cuDNN for inference.

* Add cudnn_placement_pass.
test=develop

* Set the default value of cudnn_enabled_op_types to null.
test=develop

* Add a common base class, placement_pass_base, to share code between placement passes.
test=develop

* Call EnableCUDNN in unittest.
test=develop

* Refine cudnn_placement_pass tester.

* Enable testing of cudnn_placement_pass in the inference unit tests.
test=develop

* Add a check for registered op kernels.
test=develop
Parent cc443675
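A minimal usage sketch of the new interface, assembled from the changes in this diff; the predictor creation call and the model_dir variable are assumed from the existing C++ inference API and are not part of this commit:

// Sketch: turn on cuDNN kernels for GPU inference through AnalysisConfig.
paddle::AnalysisConfig config;
config.EnableUseGpu(100, 0);   // 100 MB initial memory pool on GPU 0
config.EnableCUDNN();          // only takes effect together with IR optimization
config.SetModel(model_dir);    // model_dir: directory of the saved inference model (assumed)
config.SwitchIrOptim(true);    // otherwise EnableCUDNN() logs an error and is ignored
auto predictor = paddle::CreatePaddlePredictor(config);  // assumed existing API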
......@@ -12,21 +12,14 @@ unset(INFER_IR_PASSES CACHE) # clear the global variable
function(pass_library TARGET DEST)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(multiValueArgs SRCS DEPS DIR)
set(targetPrefix "")
# Get optional argument
set(extraMacroArgs ${ARGN})
list(LENGTH extraMacroArgs numExtraMacroArgs)
if(numExtraMacroArgs GREATER 0)
list(GET extraMacroArgs 0 targetPrefix)
endif()
cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(targetPrefix)
cc_library(${TARGET} SRCS ${targetPrefix}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(pass_library_DIR)
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
else()
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS})
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
endif()
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
......@@ -44,6 +37,7 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper)
......@@ -77,22 +71,25 @@ pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
if(WITH_GPU)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
endif()
if(ANAKIN_SUBGRAPH)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
endif()
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn)
pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(fc_mkldnn_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn)
pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
endif()
if(WITH_NGRAPH)
......@@ -121,6 +118,9 @@ cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DE
cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_simplify_with_basic_ops_pass SRCS simplify_with_basic_ops_pass_tester.cc DEPS simplify_with_basic_ops_pass)
if(WITH_GPU)
cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
endif()
if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif()
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
REGISTER_PASS(cudnn_placement_pass, paddle::framework::ir::CUDNNPlacementPass)
.RequirePassAttr("cudnn_enabled_op_types");
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/placement_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Specifies which operators should use cuDNN.
*/
class CUDNNPlacementPass : public PlacementPassBase {
private:
const std::string GetPlacementName() const { return "cuDNN"; }
const std::string GetAttrName() const { return "use_cudnn"; }
const std::unordered_set<std::string> GetOpTypesList() const {
return Get<std::unordered_set<std::string>>("cudnn_enabled_op_types");
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cudnn_placement_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
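// Registers fake GPU kernels so that PlacementPassBase::IsSupport() has
// kernel entries to look up in this unit test: conv2d and pool2d get a
// kCUDNN kernel, while depthwise_conv2d and relu only get plain kernels and
// therefore can never be switched to cuDNN by the pass.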
void RegisterOpKernel() {
static bool is_registered = false;
if (!is_registered) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
platform::CUDAPlace place = platform::CUDAPlace(0);
OpKernelType plain_kernel_type =
OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
LibraryType::kPlain);
OpKernelType cudnn_kernel_type =
OpKernelType(proto::VarType::FP32, place, DataLayout::kAnyLayout,
LibraryType::kCUDNN);
auto fake_kernel_func = [](const ExecutionContext&) -> void {
static int num_calls = 0;
num_calls++;
};
all_kernels["conv2d"][cudnn_kernel_type] = fake_kernel_func;
all_kernels["pool2d"][cudnn_kernel_type] = fake_kernel_func;
all_kernels["depthwise_conv2d"][plain_kernel_type] = fake_kernel_func;
all_kernels["relu"][plain_kernel_type] = fake_kernel_func;
is_registered = true;
}
}
void MainTest(std::initializer_list<std::string> cudnn_enabled_op_types,
unsigned expected_use_cudnn_true_count) {
// operator use_cudnn
// --------------------------------------------------
// (a,b)->concat->c -
// (c,weights,bias)->conv2d->f false
// f->relu->g -
// g->pool2d->h false
// (h,weights2,bias2)->depthwise_conv2d->k false
// k->relu->l -
Layers layers;
VarDesc* a = layers.data("a");
VarDesc* b = layers.data("b");
VarDesc* c = layers.concat(std::vector<VarDesc*>({a, b}));
VarDesc* weights_0 = layers.data("weights_0");
VarDesc* bias_0 = layers.data("bias_0");
VarDesc* f = layers.conv2d(c, weights_0, bias_0, false);
VarDesc* g = layers.relu(f);
VarDesc* h = layers.pool2d(g, false);
VarDesc* weights_1 = layers.data("weights_1");
VarDesc* bias_1 = layers.data("bias_1");
VarDesc* k = layers.depthwise_conv2d(h, weights_1, bias_1, false);
layers.relu(k);
RegisterOpKernel();
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("cudnn_placement_pass");
pass->Set("cudnn_enabled_op_types",
new std::unordered_set<std::string>(cudnn_enabled_op_types));
graph.reset(pass->Apply(graph.release()));
unsigned use_cudnn_true_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()) {
auto* op = node->Op();
if (op->HasAttr("use_cudnn") &&
boost::get<bool>(op->GetAttr("use_cudnn"))) {
++use_cudnn_true_count;
}
}
}
EXPECT_EQ(use_cudnn_true_count, expected_use_cudnn_true_count);
}
TEST(CUDNNPlacementPass, enable_conv2d) {
// 1 conv2d
MainTest({"conv2d"}, 1);
}
TEST(CUDNNPlacementPass, enable_relu_pool) {
// 1 conv2d + 1 pool2d
MainTest({"conv2d", "pool2d"}, 2);
}
TEST(CUDNNPlacementPass, enable_all) {
// 1 conv2d + 1 pool2d
// depthwise_conv2d does not have a cuDNN kernel.
MainTest({}, 2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cudnn_placement_pass);
......@@ -13,39 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
namespace paddle {
namespace framework {
namespace ir {
void MKLDNNPlacementPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies MKL-DNN placement strategy.";
const auto& op_types_list =
Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
if (!graph->Has("use_mkldnn")) {
graph->Set<bool>("use_mkldnn", new bool(true));
}
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->HasAttr("use_mkldnn") || op->HasProtoAttr("use_mkldnn")) {
if (op_types_list.empty()) {
op->SetAttr("use_mkldnn", true);
} else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) {
op->SetAttr("use_mkldnn", true);
}
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
.RequirePassAttr("mkldnn_enabled_op_types");
......@@ -14,8 +14,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/placement_pass_base.h"
namespace paddle {
namespace framework {
......@@ -24,9 +25,15 @@ namespace ir {
/*
* Specifies which operators should use MKLDNN.
*/
class MKLDNNPlacementPass : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
class MKLDNNPlacementPass : public PlacementPassBase {
private:
const std::string GetPlacementName() const { return "MKLDNN"; }
const std::string GetAttrName() const { return "use_mkldnn"; }
const std::unordered_set<std::string> GetOpTypesList() const {
return Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
}
};
} // namespace ir
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
......@@ -29,6 +30,52 @@ struct Layers {
VarDesc* data(std::string name) { return lod_tensor(name); }
VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
bool use_cudnn) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("conv2d");
op->SetInput("Input", {input->Name()});
op->SetInput("Filter", {filter->Name()});
op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("use_cudnn", use_cudnn);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* depthwise_conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
bool use_cudnn) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("depthwise_conv2d");
op->SetInput("Input", {input->Name()});
op->SetInput("Filter", {filter->Name()});
op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("use_cudnn", use_cudnn);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* pool2d(VarDesc* x, bool use_cudnn) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("pool2d");
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("use_cudnn", use_cudnn);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* relu(VarDesc* x, VarDesc* out = nullptr) {
return unary_op("relu", x, out);
}
VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
return binary_op("mul", x, y, out);
}
......@@ -52,6 +99,22 @@ struct Layers {
return out;
}
VarDesc* concat(std::vector<VarDesc*> inputs, int axis = -1) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("concat");
std::vector<std::string> input_names(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
input_names[i] = inputs[i]->Name();
}
op->SetInput("X", input_names);
op->SetOutput("Out", {out->Name()});
op->SetAttr("axis", axis);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
private:
VarDesc* lod_tensor(std::string name) {
auto* var = program_.MutableBlock(0)->Var(name);
......@@ -59,6 +122,19 @@ struct Layers {
return var;
}
VarDesc* unary_op(std::string type, VarDesc* x, VarDesc* out = nullptr) {
if (!out) {
out = lod_tensor(unique_name());
}
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("X", {x->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* binary_op(std::string type, VarDesc* x, VarDesc* y,
VarDesc* out = nullptr) {
if (!out) {
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/placement_pass_base.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace framework {
namespace ir {
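// Sets the library attribute (e.g. "use_cudnn" or "use_mkldnn") to true on
// every op node that declares the attribute and passes IsSupport(). An empty
// op-type list means "enable for all supported ops"; otherwise only the
// listed op types are switched.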
void PlacementPassBase::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Applies " << GetPlacementName() << " placement strategy.";
std::string attr_name = GetAttrName();
const auto& op_types_list = GetOpTypesList();
if (!graph->Has(attr_name)) {
graph->Set<bool>(attr_name, new bool(true));
}
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if ((op->HasAttr(attr_name) || op->HasProtoAttr(attr_name)) &&
IsSupport(op->Type())) {
if (op_types_list.empty()) {
op->SetAttr(attr_name, true);
} else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) {
op->SetAttr(attr_name, true);
}
}
}
}
}
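// An op type qualifies for cuDNN placement only if at least one of its
// kernels is registered for a GPU place with LibraryType::kCUDNN; for
// MKL-DNN, the attribute check in ApplyImpl is considered sufficient.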
bool PlacementPassBase::IsSupport(const std::string& op_type) const {
if (GetAttrName() == "use_cudnn") {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// Control-flow operators do not have kernels.
return false;
}
for (auto& kernel_pair : it->second) {
if (platform::is_gpu_place(kernel_pair.first.place_) &&
(kernel_pair.first.library_type_ == LibraryType::kCUDNN)) {
return true;
}
}
} else if (GetAttrName() == "use_mkldnn") {
return true;
}
return false;
}
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Base class for placement passes, which decide which operators should use a
* specific library (e.g. cuDNN or MKL-DNN).
*/
class PlacementPassBase : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
virtual const std::string GetPlacementName() const = 0;
virtual const std::string GetAttrName() const = 0;
virtual const std::unordered_set<std::string> GetOpTypesList() const = 0;
private:
bool IsSupport(const std::string& op_type) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
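To make the new extension point concrete, here is a hedged sketch of how a further placement pass could be derived from PlacementPassBase; the class name, attribute name, and pass name below are hypothetical and not part of this commit:

// Hypothetical example only: a placement pass for some other library "foo".
class FooPlacementPass : public PlacementPassBase {
 private:
  const std::string GetPlacementName() const { return "Foo"; }
  const std::string GetAttrName() const { return "use_foo"; }
  const std::unordered_set<std::string> GetOpTypesList() const {
    return Get<std::unordered_set<std::string>>("foo_enabled_op_types");
  }
};
// Registration would follow the same pattern as the cuDNN/MKL-DNN passes:
// REGISTER_PASS(foo_placement_pass, paddle::framework::ir::FooPlacementPass)
//     .RequirePassAttr("foo_enabled_op_types");

Note that IsSupport() in placement_pass_base.cc above only recognizes "use_cudnn" and "use_mkldnn", so a new attribute would also need a branch there.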
......@@ -16,7 +16,6 @@ limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
......
......@@ -64,6 +64,9 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(
argument->mkldnn_enabled_op_types()));
} else if (pass_name == "cudnn_placement_pass") {
pass->Set("cudnn_enabled_op_types",
new std::unordered_set<std::string>());
#ifdef PADDLE_WITH_MKLDNN
} else if (pass_name == "cpu_quantize_placement_pass") {
pass->Set("quantize_enabled_op_types",
......
......@@ -94,8 +94,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
prog_file_ = std::move(other.prog_file_);
params_file_ = std::move(other.params_file_);
// Gpu related.
// GPU related.
CP_MEMBER(use_gpu_);
CP_MEMBER(use_cudnn_);
CP_MEMBER(device_id_);
CP_MEMBER(memory_pool_init_size_mb_);
......@@ -152,6 +153,17 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
Update();
}
void AnalysisConfig::EnableCUDNN() {
#ifdef PADDLE_WITH_CUDA
use_cudnn_ = use_gpu_;
#else
LOG(ERROR) << "Please compile with CUDA first to use cuDNN";
use_cudnn_ = false;
#endif
Update();
}
void AnalysisConfig::EnableMKLDNN() {
#ifdef PADDLE_WITH_MKLDNN
use_mkldnn_ = true;
......@@ -243,7 +255,6 @@ void AnalysisConfig::Update() {
} else {
pass_builder_.reset(new CpuPassStrategy);
}
} else {
if (use_gpu()) {
pass_builder_.reset(new GpuPassStrategy(
......@@ -262,6 +273,16 @@ void AnalysisConfig::Update() {
}
}
if (use_gpu() && use_cudnn_) {
#ifdef PADDLE_WITH_CUDA
if (!enable_ir_optim_) {
LOG(ERROR) << "EnableCUDNN() only works when IR optimization is enabled.";
} else {
pass_builder()->EnableCUDNN();
}
#endif
}
if (use_ngraph_) {
if (!enable_ir_optim_) {
LOG(ERROR)
......
......@@ -101,6 +101,13 @@ struct AnalysisConfig {
*/
float fraction_of_gpu_memory_for_pool() const;
/** Turn on CUDNN
*/
void EnableCUDNN();
/** A boolean state telling whether to use cuDNN.
*/
bool cudnn_enabled() const { return use_cudnn_; }
/** \brief Control whether to perform IR graph optimization.
*
* If turned off, the AnalysisConfig will act just like a NativeConfig.
......@@ -269,6 +276,8 @@ struct AnalysisConfig {
int device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool use_cudnn_{false};
// TensorRT related.
bool use_tensorrt_{false};
// For workspace_size, refer it from here:
......
......@@ -125,6 +125,13 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
use_gpu_ = true;
}
void GpuPassStrategy::EnableCUDNN() {
if (!use_cudnn_) {
passes_.insert(passes_.begin(), "cudnn_placement_pass");
}
use_cudnn_ = true;
}
void GpuPassStrategy::EnableMKLDNN() {
LOG(ERROR) << "GPU not support MKLDNN yet";
}
......@@ -164,6 +171,8 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
use_gpu_ = false;
}
void CpuPassStrategy::EnableCUDNN() { LOG(ERROR) << "CPU not support cuDNN"; }
void CpuPassStrategy::EnableMKLDNN() {
// TODO(Superjomn) Consider the way to mix CPU with GPU.
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -85,6 +85,10 @@ class PassStrategy : public PaddlePassBuilder {
explicit PassStrategy(const std::vector<std::string> &passes)
: PaddlePassBuilder(passes) {}
/** Enable the use of cuDNN kernel
*/
virtual void EnableCUDNN() {}
/** The MKLDNN control exists in both CPU and GPU mode, because there can be
* still some CPU kernels running in CPU mode.
*/
......@@ -124,6 +128,7 @@ class CpuPassStrategy : public PassStrategy {
virtual ~CpuPassStrategy() = default;
void EnableCUDNN() override;
void EnableNgraph() override;
void EnableMKLDNN() override;
void EnableMkldnnQuantizer() override;
......@@ -142,13 +147,18 @@ class GpuPassStrategy : public PassStrategy {
explicit GpuPassStrategy(const GpuPassStrategy &other)
: PassStrategy(other.AllPasses()) {
use_gpu_ = true;
use_cudnn_ = other.use_cudnn_;
}
void EnableCUDNN() override;
void EnableNgraph() override;
void EnableMKLDNN() override;
void EnableMkldnnQuantizer() override;
virtual ~GpuPassStrategy() = default;
protected:
bool use_cudnn_{false};
};
extern const std::vector<std::string> kTRTSubgraphPasses;
......
......@@ -32,6 +32,7 @@ TEST(AnalysisPredictor, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.EnableCUDNN();
config.SetModel(model_dir);
config.pass_builder()->TurnOnDebug();
......
......@@ -23,7 +23,7 @@ namespace inference {
TEST(resnet50, compare_continuous_input) {
std::string model_dir = FLAGS_infer_model + "/resnet50";
compare_continuous_input(model_dir, true);
compare_continuous_input(model_dir, /* use_tensorrt */ true);
}
} // namespace inference
......
......@@ -63,6 +63,7 @@ void SetConfig<AnalysisConfig>(AnalysisConfig* config, std::string model_dir,
config->pass_builder()->DeletePass("fc_fuse_pass");
config->pass_builder()->TurnOnDebug();
} else {
config->EnableCUDNN();
config->SwitchIrOptim();
}
}
......