Commit 943ad478 authored by bingyanghuang, committed by Tao Luo

One possible solution to add flexibility for mkldnn placement pass (#14768)

* Choose to turn on use_mkldnn attribute v1

* Fix mkldnn_op empty bug

* format change test=develop

* fix ci test=develop

* fix ci test and add test in dam test=develop

* add example to dam compare test test=develop

* review changes test=develop
Parent b1d3a1c8
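
Before the diff, a minimal usage sketch of the new whitelist API may help (not part of the commit; the model path, op names, and surrounding predictor setup are illustrative assumptions, while EnableMKLDNN and SetMKLDNNOp are the methods touched below):

#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

void RunWithMKLDNNWhitelist() {
  paddle::contrib::AnalysisConfig cfg;
  cfg.model_dir = "./model";  // illustrative model directory
  cfg.EnableMKLDNN();
  // Only the op types listed here get use_mkldnn=true during the
  // placement pass; leaving the set empty (the default) enables
  // MKL-DNN for every op that supports it, i.e. the old behavior.
  cfg.SetMKLDNNOp({"conv2d", "pool2d"});
  auto predictor =
      paddle::CreatePaddlePredictor<paddle::contrib::AnalysisConfig>(cfg);
  // ... build PaddleTensor inputs and call predictor->Run(...).
}

Because an empty set keeps the previous behavior, existing EnableMKLDNN() users are unaffected.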
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
+#include <string>
 
 namespace paddle {
 namespace framework {
@@ -21,9 +22,16 @@ namespace ir {
 std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   VLOG(3) << "Applies MKL-DNN placement strategy.";
+  const auto& op_types_list =
+      Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) {
-      n->Op()->SetAttr("use_mkldnn", true);
+      if (op_types_list.empty()) {
+        n->Op()->SetAttr("use_mkldnn", true);
+      } else if (std::find(op_types_list.begin(), op_types_list.end(),
+                           n->Name()) != op_types_list.end()) {
+        n->Op()->SetAttr("use_mkldnn", true);
+      }
     }
   }
   return graph;
@@ -33,5 +41,5 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(mkldnn_placement_pass,
-              paddle::framework::ir::MKLDNNPlacementPass);
+REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
+    .RequirePassAttr("mkldnn_enabled_op_types");
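
The placement rule introduced above is: an empty whitelist enables MKL-DNN for every op carrying a use_mkldnn attribute, while a non-empty one enables it only for the listed op types. A self-contained sketch of that rule (the helper name is hypothetical; note that unordered_set::count would be the constant-time equivalent of the patch's linear std::find, with the same result):

#include <string>
#include <unordered_set>

// Hypothetical helper mirroring the pass's decision rule.
bool ShouldEnableMKLDNN(const std::string& op_type,
                        const std::unordered_set<std::string>& whitelist) {
  // Empty whitelist: enable MKL-DNN for every supporting op.
  if (whitelist.empty()) return true;
  // Non-empty whitelist: enable only the listed op types.
  return whitelist.count(op_type) > 0;
}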
@@ -116,6 +116,10 @@ struct Argument {
   DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses,
                       std::vector<std::string>);
 
+  // Pass a set of op types to enable their MKL-DNN kernels.
+  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
+                      std::unordered_set<std::string>);
+
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
   DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
   DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
......
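
DECL_ARGUMENT_FIELD generates the accessor pair the rest of the patch calls, argument->mkldnn_enabled_op_types() and argument_.SetMKLDNNEnabledOpTypes(...). As a rough sketch only (the real macro expansion differs in details such as validity tracking), the declared field behaves like:

#include <string>
#include <unordered_set>

// Simplified sketch of what the macro provides; not the actual expansion.
struct ArgumentSketch {
  const std::unordered_set<std::string>& mkldnn_enabled_op_types() const {
    return mkldnn_enabled_op_types_;
  }
  void SetMKLDNNEnabledOpTypes(const std::unordered_set<std::string>& v) {
    mkldnn_enabled_op_types_ = v;
    mkldnn_enabled_op_types_valid_ = true;
  }
  bool mkldnn_enabled_op_types_valid() const {
    return mkldnn_enabled_op_types_valid_;
  }

 private:
  std::unordered_set<std::string> mkldnn_enabled_op_types_;
  bool mkldnn_enabled_op_types_valid_{false};
};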
@@ -63,6 +63,11 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
       pass_num++;
     }
+    if (pass_name == "mkldnn_placement_pass") {
+      pass->Set("mkldnn_enabled_op_types",
+                new std::unordered_set<std::string>(
+                    argument->mkldnn_enabled_op_types()));
+    }
 
     if (pass_name == "tensorrt_subgraph_pass") {
       PADDLE_ENFORCE(argument->tensorrt_node_teller_valid());
......
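
This follows the standard pass-attribute handoff: the pass declares the attribute it needs via RequirePassAttr, IRPassManager injects a heap-allocated value with pass->Set, and ApplyImpl reads it back with Get<T>. A stripped-down illustration of that contract (the PassAttrs type here is a stand-in, not Paddle's actual implementation):

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Illustrative stand-in: the caller transfers ownership of a
// heap-allocated value under a string key, and the pass later reads it
// back by the same key and type.
class PassAttrs {
 public:
  template <typename T>
  void Set(const std::string& key, T* value) {
    attrs_[key] = std::shared_ptr<void>(value);  // owns and deletes value
  }
  template <typename T>
  const T& Get(const std::string& key) const {
    auto it = attrs_.find(key);
    assert(it != attrs_.end() && "required pass attribute missing");
    return *static_cast<const T*>(it->second.get());
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> attrs_;
};

// Usage mirroring the diff:
//   PassAttrs pass;
//   pass.Set("mkldnn_enabled_op_types",
//            new std::unordered_set<std::string>{"conv2d"});
//   const auto& types =
//       pass.Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");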
@@ -49,6 +49,10 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
   cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
   // fields from this.
   enable_ir_optim = other.enable_ir_optim;
+  // For mkldnn
+  use_mkldnn_ = other.use_mkldnn_;
+  mkldnn_enabled_op_types_ = other.mkldnn_enabled_op_types_;
+
   use_feed_fetch_ops = other.use_feed_fetch_ops;
   use_tensorrt_ = other.use_tensorrt_;
   tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
@@ -77,6 +81,10 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
   cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
   // fields from this.
   enable_ir_optim = other.enable_ir_optim;
+  // For mkldnn
+  use_mkldnn_ = other.use_mkldnn_;
+  mkldnn_enabled_op_types_ = other.mkldnn_enabled_op_types_;
+
   use_feed_fetch_ops = other.use_feed_fetch_ops;
   use_tensorrt_ = other.use_tensorrt_;
   tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
......
@@ -327,6 +327,10 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
   }
 
+  if (config_.use_mkldnn_) {
+    argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
+  }
+
   auto passes = config_.pass_builder()->AllPasses();
   if (!config_.enable_ir_optim) passes.clear();
   argument_.SetIrAnalysisPasses(passes);
......
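
Taken together, the plumbing forms a single path: AnalysisConfig::SetMKLDNNOp stores the whitelist, AnalysisPredictor::OptimizeInferenceProgram forwards it into the Argument via SetMKLDNNEnabledOpTypes (only when use_mkldnn_ is set), and IRPassManager copies it onto mkldnn_placement_pass as the required "mkldnn_enabled_op_types" attribute.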
@@ -16,6 +16,7 @@
 #include <cassert>
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
 // Here we include some header files with relative paths, for that in deploy,
@@ -53,6 +54,9 @@ struct AnalysisConfig : public NativeConfig {
   void EnableMKLDNN();
   bool use_mkldnn() const { return use_mkldnn_; }
+  void SetMKLDNNOp(std::unordered_set<std::string> op_list) {
+    mkldnn_enabled_op_types_ = op_list;
+  }
 
   // Specify the memory buffer of program and parameter
   void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size,
@@ -64,6 +68,7 @@ struct AnalysisConfig : public NativeConfig {
  protected:
   bool use_tensorrt_{false};
   bool use_mkldnn_{false};
+  std::unordered_set<std::string> mkldnn_enabled_op_types_;
   int tensorrt_workspace_size_;
   int tensorrt_max_batchsize_;
   std::unique_ptr<PassStrategy> pass_builder_;
......
@@ -194,6 +194,8 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    std::unordered_set<std::string> op_list = {"conv3d"};
+    cfg.SetMKLDNNOp(op_list);
   }
 
   std::vector<PaddleTensor> outputs;
@@ -236,6 +238,8 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    std::unordered_set<std::string> op_list = {"conv3d"};
+    cfg.SetMKLDNNOp(op_list);
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
......
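
The two DAM tester hunks exercise the whitelist end to end: the set is restricted to conv3d (presumably chosen because the DAM model contains conv3d ops), so only those ops are placed on MKL-DNN, and the compare test verifies that the outputs still match the non-MKL-DNN CPU run.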