diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index bd27b1f5f34475db793d643f2d12508e0aea631e..05a8e8f1b5e3e33ae73047176b3b54536b77a22d 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -213,6 +213,11 @@ struct Argument {
   DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool);
   DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool);
 
+  DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool);
+  DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int);
+  DECL_ARGUMENT_FIELD(dlnne_max_batch_size, DlnneMaxBatchSize, int);
+  DECL_ARGUMENT_FIELD(dlnne_workspace_size, DlnneWorkspaceSize, int);
+
   DECL_ARGUMENT_FIELD(lite_passes_filter, LitePassesFilter,
                       std::vector<std::string>);
   DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index a4e263e2f464c4021b049093c49ddaecb056284f..06d48a536664486043c6615f16b442b76d818bb7 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -166,6 +166,11 @@ void IRPassManager::CreatePasses(Argument *argument,
       // run fp16.
       pass->Set("disable_trt_plugin_fp16",
                 new bool(argument->disable_trt_plugin_fp16()));
+    } else if (pass_name == "dlnne_subgraph_pass") {
+      pass->Set("min_subgraph_size",
+                new int(argument->dlnne_min_subgraph_size()));
+      pass->Set("program",
+                new framework::ProgramDesc *(&argument->main_program()));
     }
     if (pass_name == "lite_subgraph_pass") {
       bool enable_int8 =
diff --git a/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt b/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
index e35178428cc7bae7f5795e2a4652b808956f6776..330f7a99847344f7359a29e26efac71e969bf06d 100644
--- a/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt
@@ -20,3 +20,15 @@ if (WITH_LITE)
   set(INFER_IR_PASSES ${INFER_IR_PASSES} lite_subgraph_pass CACHE INTERNAL "")
   cc_test(lite_subgraph_pass_tester SRCS lite_subgraph_pass_tester.cc DEPS lite_subgraph_pass gtest glog)
 endif()
+
+MESSAGE("WITH_DLNNE:${WITH_DLNNE}")
+if(WITH_DLNNE)
+  cc_library(dlnne_subgraph_pass SRCS dlnne_subgraph_pass.cc DEPS ${analysis_deps} subgraph_util)
+  set(analysis_deps ${analysis_deps}
+      subgraph_util dlnne_subgraph_pass
+      CACHE INTERNAL "")
+
+  set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h.tmp)
+  file(APPEND ${pass_file} "USE_PASS(dlnne_subgraph_pass);\n")
+  set(INFER_IR_PASSES ${INFER_IR_PASSES} dlnne_subgraph_pass CACHE INTERNAL "")
+endif()
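For orientation: DECL_ARGUMENT_FIELD generates a typed getter/setter pair on Argument (used later in this patch as argument_.SetUseDlnne(true) and argument->dlnne_min_subgraph_size()), and the ir_pass_manager.cc hunk above forwards the value into the pass as a named attribute that the pass reads back with Get<int>("min_subgraph_size"). A minimal sketch of the accessors the macro produces — simplified, not the actual Paddle macro expansion, which also tracks a per-field validity flag:

```cpp
// Hedged sketch of what DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool) and
// DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int)
// roughly expand to on struct Argument.
struct ArgumentSketch {
  bool use_dlnne() const { return use_dlnne_; }
  void SetUseDlnne(bool v) { use_dlnne_ = v; }
  int dlnne_min_subgraph_size() const { return dlnne_min_subgraph_size_; }
  void SetDlnneMinSubgraphSize(int v) { dlnne_min_subgraph_size_ = v; }

 private:
  bool use_dlnne_{false};
  int dlnne_min_subgraph_size_{3};
};
```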
diff --git a/paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h b/paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h
new file mode 100644
index 0000000000000000000000000000000000000000..ae977c1403a8793b0611496702515f1df952d5a1
--- /dev/null
+++ b/paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h
@@ -0,0 +1,21 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <string>
+
+namespace paddle {
+namespace inference {
+
+int RegisterPyFunc(const std::string& name, void* pfn);
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8f789139af9bfc35841f284d043a2c86f5803e93
--- /dev/null
+++ b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc
@@ -0,0 +1,351 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <sys/stat.h>
+
+#include <algorithm>
+#include <fstream>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/subgraph_detector.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+#include "paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h"
+#include "paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace inference {
+
+int (*PyConvertGraph)(const char *graph_name);
+
+int RegisterPyFunc(const std::string &name, void *pfn) {
+  if (name.compare("convert_graph") == 0) {
+    PyConvertGraph = reinterpret_cast<decltype(PyConvertGraph)>(pfn);
+  }
+
+  return 0;
+}
+
+int ConvertGraph(std::string graph_name) {
+  LOG(INFO) << "start converting graph " << graph_name;
+
+  PyConvertGraph(graph_name.c_str());
+
+  return 0;
+}
+
+namespace analysis {
+
+using framework::ir::Node;
+
+void analysis::DlnneSubgraphPass::ApplyImpl(framework::ir::Graph *graph) const {
+  static std::unordered_set<std::string> teller_set{
+      "mul", "matmul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
+      "hard_swish", "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
+      "elementwise_add", "elementwise_mul", "dropout", "prelu",
+      "conv2d_transpose", "leaky_relu",
+      // "fc",
+      "shuffle_channel", "swish", "split",
+      // "instance_norm",
+      "gelu",
+      // "layer_norm",
+      // "scale",
+      // "stack",
+      "relu6", "reshape2", "transpose2", "concat", "slice",
+  };
+
+  framework::ir::FusePassBase::Init("dlnne_subgraph_pass", graph);
+
+  auto teller = [&](const framework::ir::Node *node) {
+    if (!node->IsOp() || !node->Op()) return false;
+    return teller_set.find(node->Op()->Type()) != teller_set.end();
+  };
+
+  framework::ir::SubGraphFuser fuser(
+      graph, teller, Get<int>("min_subgraph_size") /*min subgraph size*/,
+      "dlnne_engine");
+  fuser();
+
+  std::vector<std::string> graph_param_names =
+      ExtractParameters(graph->Nodes());
+  // Those parameters already exist in dlnne, and should not have another
+  // copy in fluid.
+  std::vector<std::string> repetitive_params;
+
+  for (auto *node : graph->Nodes()) {
+    if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
+      CreateDlnneOp(node, graph, graph_param_names, &repetitive_params);
+
+      std::unordered_set<const Node *> nodes2remove(
+          framework::ir::Agent(node).subgraph()->begin(),
+          framework::ir::Agent(node).subgraph()->end());
+      framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
+    }
+  }
+
+  std::unordered_set<const Node *> nodes2remove;
+  for (auto *node : graph->Nodes()) {
+    if (node->IsOp() && framework::ir::Agent(node).deleted()) {
+      nodes2remove.insert(node);
+    }
+  }
+  framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
+}
+
+std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
+                              const std::set<std::string> &engine_outputs,
+                              const std::string &predictor_id) {
+  std::string engine_hash_key = "";
+  for (auto name : engine_inputs) {
+    engine_hash_key += name;
+  }
+  for (auto name : engine_outputs) {
+    engine_hash_key += name;
+  }
+  engine_hash_key += predictor_id;
+  auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
+  return engine_key;
+}
+
+std::string replace_name(std::string name, const char *raw,
+                         const char *new_char) {
+  std::string r_name = name;
+  int pos = r_name.find(raw);
+  while (pos >= 0) {
+    r_name = r_name.replace(pos, 1, new_char);
+    pos = r_name.find(raw);
+  }
+  return r_name;
+}
+
+void DlnneSubgraphPass::CreateDlnneOp(
+    framework::ir::Node *node, framework::ir::Graph *graph,
+    const std::vector<std::string> &graph_params,
+    std::vector<std::string> *repetitive_params) const {
+  auto *op_desc = node->Op();
+  auto &subgraph = *framework::ir::Agent(node).subgraph();
+  PADDLE_ENFORCE_EQ(subgraph.empty(), false,
+                    platform::errors::PreconditionNotMet(
+                        "The subgraph should not be empty."));
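+
+  // The rest of this function does three things:
+  //  1. clone every op of the fused subgraph into a temporary ProgramDesc,
+  //     renaming "/" in parameter names to ".";
+  //  2. rewrite the fused node into a `dlnne_engine` op whose Xs/Ys are the
+  //     variables crossing the subgraph boundary;
+  //  3. serialize the temporary program and its weights to
+  //     ./dump/<engine_key>/ and invoke the registered converter.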
+
+  // A fake block desc.
+  framework::proto::BlockDesc block_proto;
+  framework::BlockDesc block_desc(nullptr, &block_proto);
+  block_desc.Proto()->set_parent_idx(-1);
+  block_desc.Proto()->set_idx(0);
+  LOG(INFO) << "---  detect a sub-graph with " << subgraph.size() << " nodes";
+
+  // for debug
+  framework::ProgramDesc tmp_dump_program_desc;
+  auto *tmp_dump_main_block = tmp_dump_program_desc.MutableBlock(0);
+
+  std::unordered_map<std::string, framework::VarDesc *> name_var_desc;
+  std::set<std::string> name_var_input_nodes;
+  std::set<std::string> name_var_output_nodes;
+  std::set<std::string> name_ops;
+
+  for (auto *node : subgraph) {
+    auto *op = block_desc.AppendOp();
+    *op->Proto() = *node->Op()->Proto();
+
+    // debug
+    {
+      name_ops.insert(node->Name());
+      auto *tmp_dump_new_block_op = tmp_dump_main_block->AppendOp();
+
+      framework::OpDesc op_desc;
+      op_desc.CopyFrom(*node->Op());
+
+      for (auto argument_name : op_desc.InputArgumentNames()) {
+        if (std::count(graph_params.begin(), graph_params.end(),
+                       argument_name) > 0) {
+          op_desc.Rename(argument_name, replace_name(argument_name, "/", "."));
+        }
+      }
+      for (auto argument_name : op_desc.OutputArgumentNames()) {
+        if (std::count(graph_params.begin(), graph_params.end(),
+                       argument_name) > 0) {
+          op_desc.Rename(argument_name, replace_name(argument_name, "/", "."));
+        }
+      }
+      *tmp_dump_new_block_op->Proto() = *op_desc.Proto();
+
+      for (auto *x : node->inputs) {
+        if (x->IsVar()) {
+          name_var_desc[x->Name()] = x->Var();
+        }
+        if (std::count(graph_params.begin(), graph_params.end(), x->Name()) ==
+            0)
+          name_var_input_nodes.insert(x->Name());
+      }
+
+      for (auto *x : node->outputs) {
+        if (x->IsVar()) {
+          name_var_desc[x->Name()] = x->Var();
+        }
+        if (std::count(graph_params.begin(), graph_params.end(), x->Name()) ==
+            0)
+          name_var_output_nodes.insert(x->Name());
+      }
+    }
+  }
+
+  std::set<std::string> valid_input_names;
+  std::set<std::string> valid_output_names;
+  for (auto name : name_var_output_nodes) {
+    if (name_var_input_nodes.find(name) == name_var_input_nodes.end()) {
+      valid_output_names.insert(name);
+    }
+  }
+
+  for (auto name : name_var_input_nodes) {
+    if (name_var_output_nodes.find(name) == name_var_output_nodes.end()) {
+      valid_input_names.insert(name);
+    }
+  }
+
+  // Then, we will use the input_names_with_id and output_names_with_id to
+  // generate the engine key.
+  // So, we use set instead of unordered_set here to ensure that the engine
+  // key is unique.
+  std::set<std::string> input_names;
+  std::set<std::string> input_names_with_id;
+  std::vector<std::string> params;
+  // If we delete the fluid copy of params shared by more than one op, there
+  // will be a problem, so we filter them out.
+
+  // The node->inputs contains input tensors and parameters.
+  for (auto *x : node->inputs) {
+    input_names.insert(x->Name());
+    input_names_with_id.insert(x->Name() + std::to_string(x->id()));
+    if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) {
+      params.push_back(x->Name());
+    }
+  }
+
+  std::set<std::string> output_names;
+  std::set<std::string> output_names_with_id;
+  std::vector<int> origin_output_dims;
+  for (auto *x : node->outputs) {
+    origin_output_dims.push_back(x->Var()->GetShape().size());
+    output_names.insert(x->Name());
+    output_names_with_id.insert(x->Name() + std::to_string(x->id()));
+  }
+
+  std::unordered_map<std::string, std::string> output_name_map;
+  std::unordered_map<std::string, framework::ir::Node *> graph_var_map;
+
+  for (framework::ir::Node *node : graph->Nodes()) {
+    if (node->IsVar() && node->Var()) {
+      graph_var_map[node->Name()] = node;
+    }
+  }
+
+  // Set attrs
+  op_desc->SetType("dlnne_engine");
+  op_desc->SetInput("Xs", std::vector<std::string>(valid_input_names.begin(),
+                                                   valid_input_names.end()));
+
+  op_desc->SetOutput("Ys", std::vector<std::string>(valid_output_names.begin(),
+                                                    valid_output_names.end()));
+
+  op_desc->SetAttr("parameters", params);
+  auto engine_key = GenerateEngineKey(input_names_with_id,
+                                      output_names_with_id, std::to_string(0));
+  op_desc->SetAttr("engine_key", engine_key);
+  auto *scope = param_scope();
+
+  {
+    std::set<std::string> input_names;
+
+    for (auto name : name_var_input_nodes) {
+      if (name_var_output_nodes.find(name) == name_var_output_nodes.end()) {
+        input_names.insert(name);
+      }
+    }
+
+    // add feed to subgraph:
+    int input_idx = 0;
+    for (auto input_name : input_names) {
+      auto *feed0 = tmp_dump_main_block->AppendOp();
+      feed0->SetType("feed");
+      feed0->SetInput("X", {"feed"});
+      feed0->SetOutput("Out", {input_name});
+      feed0->SetAttr("col", input_idx);
+      input_idx++;
+    }
+    // add fetch to subgraph:
+    int output_idx = 0;
+    for (auto output_name : valid_output_names) {
+      auto *fetch0 = tmp_dump_main_block->AppendOp();
+      fetch0->SetType("fetch");
+      fetch0->SetInput("X", {output_name});
+      fetch0->SetOutput("Out", {"out"});
+      fetch0->SetAttr("col", output_idx);
+      output_idx++;
+    }
+
+    mkdir("./dump", 0777);
+    std::string dir_name = "./dump/" + engine_key;
+    mkdir(dir_name.c_str(), 0777);
+    std::ofstream m_stream;
+    m_stream.open(dir_name + "/__model__", std::ios::out);
+
+    VLOG(4) << "name_var_desc size:" << name_var_desc.size();
+
+    for (auto &kv : name_var_desc) {
+      auto *new_add_var = tmp_dump_main_block->Proto()->add_vars();
+      *new_add_var = *kv.second->Proto();
+      auto *variable_tmp = scope->FindVar(kv.first);
+      if (variable_tmp != nullptr) {
+        *new_add_var->mutable_name() = replace_name(kv.first, "/", ".");
+        new_add_var->set_persistable(true);
+      } else {
+        new_add_var->set_persistable(false);
+      }
+    }
+
+    for (auto param_name : params) {
+      auto *var = scope->FindVar(param_name);
+      if (var != nullptr) {
+        auto *var_t = var->GetMutable<framework::LoDTensor>();
+        std::ofstream p_stream;
+        p_stream.open(dir_name + "/" + replace_name(param_name, "/", "."),
+                      std::ios::out);
+        platform::DeviceContextPool &pool =
+            platform::DeviceContextPool::Instance();
+        auto &dev_ctx = *pool.Get(var_t->place());
+        framework::SerializeToStream(p_stream, *var_t, dev_ctx);
+        p_stream.close();
+      }
+    }
+
+    std::string model;
+
+    tmp_dump_program_desc.Proto()->SerializeToString(&model);
+    m_stream << model;
+    m_stream.close();
+
+    op_desc->SetBlockAttr("sub_block", tmp_dump_main_block);
+    op_desc->SetAttr("subgraph", model);
+    op_desc->Flush();
+
+    ConvertGraph(engine_key);
+  }
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_PASS(dlnne_subgraph_pass,
+              paddle::inference::analysis::DlnneSubgraphPass);
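The pass only works if a converter was registered beforehand under the name "convert_graph" via RegisterPyFunc; ConvertGraph then calls it with the engine key. A minimal sketch of registering a native callback — the converter body and its name are placeholders, since the real converter is supplied from the Python side outside this patch:

```cpp
#include <string>

#include "paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h"

// Placeholder converter: a real one would read ./dump/<key>/__model__ plus
// the dumped weights and write ./dump/<key>/<key>.onnx for DlnneEngineOp.
static int FakeConvertGraph(const char *graph_name) { return 0; }

int main() {
  paddle::inference::RegisterPyFunc(
      "convert_graph", reinterpret_cast<void *>(&FakeConvertGraph));
  return 0;
}
```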
diff --git a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a1d2506fdb09b5ecb63f8f922490eb4c8c01e2d
--- /dev/null
+++ b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/inference/analysis/ir_passes/subgraph_util.h"
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+class Graph;
+class Node;
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+
+int ConvertGraph(std::string graph_name);
+
+namespace analysis {
+
+class DlnneSubgraphPass : public framework::ir::FusePassBase {
+ public:
+  void ApplyImpl(framework::ir::Graph *graph) const override;
+
+ private:
+  void CleanIntermediateOutputs(framework::ir::Node *node);
+  void CreateDlnneOp(framework::ir::Node *x, framework::ir::Graph *graph,
+                     const std::vector<std::string> &graph_params,
+                     std::vector<std::string> *repetitive_params) const;
+};
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 0622fb27d9e38c87a98fcb86da64bdb21570e67d..7e874b94decbf6053f0882d5d22825584c4fc496 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -26,6 +26,7 @@ namespace paddle {
 struct MkldnnQuantizerConfig;
 
 extern const std::vector<std::string> kTRTSubgraphPasses;
+extern const std::vector<std::string> kDlnneSubgraphPasses;
 extern const std::vector<std::string> kLiteSubgraphPasses;
 
 PassStrategy *AnalysisConfig::pass_builder() const {
@@ -134,6 +135,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(trt_use_static_engine_);
   CP_MEMBER(trt_use_calib_mode_);
   CP_MEMBER(trt_use_oss_);
+  // Dlnne related
+  CP_MEMBER(use_dlnne_);
+  CP_MEMBER(dlnne_min_subgraph_size_);
   // MKLDNN related.
   CP_MEMBER(use_mkldnn_);
   CP_MEMBER(mkldnn_enabled_op_types_);
@@ -211,6 +215,21 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
       pass_builder_->DeletePass(ps);
     }
   }
+  if (use_dlnne_) {
+    auto all_passes = kDlnneSubgraphPasses;
+    auto other_passes = other.pass_builder()->AllPasses();
+    // We should sort them, because the user may call the SwitchIrDebug
+    // interface, which will change the passes.
+    std::sort(all_passes.begin(), all_passes.end());
+    std::sort(other_passes.begin(), other_passes.end());
+    std::vector<std::string> deleted_passes;
+    std::set_difference(all_passes.begin(), all_passes.end(),
+                        other_passes.begin(), other_passes.end(),
+                        std::inserter(deleted_passes, deleted_passes.begin()));
+    for (auto ps : deleted_passes) {
+      pass_builder_->DeletePass(ps);
+    }
+  }
 }
 
 void AnalysisConfig::EnableCUDNN() {
@@ -309,6 +328,12 @@ void AnalysisConfig::EnableTensorRtEngine(
 #endif
 }
 
+void AnalysisConfig::EnableDlnne(int min_subgraph_size) {
+  use_dlnne_ = true;
+  dlnne_min_subgraph_size_ = min_subgraph_size;
+  Update();
+}
+
 void AnalysisConfig::SetTRTDynamicShapeInfo(
     std::map<std::string, std::vector<int>> min_input_shape,
     std::map<std::string, std::vector<int>> max_input_shape,
@@ -383,6 +408,14 @@ void AnalysisConfig::Update() {
       pass_builder()->AppendPass(pass);
     }
   }
+  LOG(INFO) << "use_dlnne_:" << use_dlnne_;
+  if (use_dlnne_) {
+    pass_builder()->ClearPasses();
+    for (const auto &pass : kDlnneSubgraphPasses) {
+      pass_builder()->AppendPass(pass);
+    }
+  }
+
   if (use_gpu() && use_cudnn_) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (!enable_ir_optim_) {
@@ -479,6 +512,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << tensorrt_max_batchsize_;
   ss << tensorrt_min_subgraph_size_;
 
+  ss << use_dlnne_;
+  ss << dlnne_min_subgraph_size_;
+
   for (auto &op : trt_disabled_ops_) ss << op.c_str();
   ss << ";";
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 4b6c746d57525ab5c7289afb4e52dc78b9995b8a..698cbea5eb83b775abea3d84e03a77cd9b2a72c7 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -537,6 +537,12 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
   }
 
+  if (config_.dlnne_enabled()) {
+    LOG(INFO) << "Dlnne subgraph is enabled";
+    argument_.SetUseDlnne(true);
+    argument_.SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_);
+  }
+
   if (config_.lite_engine_enabled()) {
     argument_.SetCpuMathLibraryNumThreads(
         config_.cpu_math_library_num_threads());
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index e492b32cb6cbefcc121b616450170e5cc22bb913..446d6770f6399940754f176c1a0cc1af14ae72db 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -360,6 +360,9 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   bool tensorrt_dla_enabled() { return trt_use_dla_; }
 
+  void EnableDlnne(int min_subgraph_size = 3);
+  bool dlnne_enabled() const { return use_dlnne_; }
+
   ///
   /// \brief Turn on the usage of Lite sub-graph engine.
   ///
@@ -627,6 +630,10 @@ struct PD_INFER_DECL AnalysisConfig {
   std::vector<std::string> trt_disabled_ops_{};
   bool disable_trt_plugin_fp16_{false};
 
+  // dlnne related.
+  bool use_dlnne_{false};
+  int dlnne_min_subgraph_size_{3};
+
   // memory reuse related.
   bool enable_memory_optim_{false};
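The resulting C++ entry point is deliberately small. A hedged usage sketch (the model paths are placeholders):

```cpp
#include <cassert>

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

int main() {
  paddle::AnalysisConfig config;
  // Placeholder paths; any combined-model layout AnalysisConfig accepts works.
  config.SetModel("/path/to/__model__", "/path/to/__params__");
  config.EnableDlnne(/*min_subgraph_size=*/3);  // 3 is also the default
  assert(config.dlnne_enabled());
  return 0;
}
```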
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 1d77ddaf73ef700e15330e40b356f2b03fb2401e..2b7333edae0dae1f0313bf71fc824c922e20b84d 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -110,6 +110,15 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "transpose_flatten_concat_fuse_pass",
 });
 
+const std::vector<std::string> kDlnneSubgraphPasses({
+    "is_test_pass",                  //
+    "simplify_with_basic_ops_pass",  //
+    "conv_bn_fuse_pass",             //
+    "depthwise_conv_bn_fuse_pass",   //
+    "shuffle_channel_detect_pass",   //
+    "dlnne_subgraph_pass",           //
+});
+
 const std::vector<std::string> kLiteSubgraphPasses({
 #ifdef PADDLE_WITH_LITE
     "lite_subgraph_pass",
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index a725ebab35eadaaaab76a3a7c4580f95b64d827d..d7556b50031b7d63b75e1e0d12fa173f8fe9fd33 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -242,6 +242,9 @@ class PD_INFER_DECL XpuPassStrategy final : public PassStrategy {
 /// \brief List of tensorRT subgraph passes.
 PD_INFER_DECL extern const std::vector<std::string> kTRTSubgraphPasses;
 
+/// \brief List of dlnne subgraph passes.
+PD_INFER_DECL extern const std::vector<std::string> kDlnneSubgraphPasses;
+
 /// \brief List of lite subgraph passes.
 PD_INFER_DECL extern const std::vector<std::string> kLiteSubgraphPasses;
 
diff --git a/paddle/fluid/inference/capi/pd_config.cc b/paddle/fluid/inference/capi/pd_config.cc
index 231639667244d8646faa94ca453227de923c5814..9bb52ba57802512f393c23f957cc38ddabb878b1 100644
--- a/paddle/fluid/inference/capi/pd_config.cc
+++ b/paddle/fluid/inference/capi/pd_config.cc
@@ -260,6 +260,22 @@ bool PD_TensorrtEngineEnabled(const PD_AnalysisConfig* config) {
   return config->config.tensorrt_engine_enabled();
 }
 
+void PD_EnableDlnne(PD_AnalysisConfig* config, int min_subgraph_size) {
+  PADDLE_ENFORCE_NOT_NULL(
+      config,
+      paddle::platform::errors::InvalidArgument(
+          "The pointer of analysis configuration shouldn't be nullptr"));
+  config->config.EnableDlnne(min_subgraph_size);
+}
+
+bool PD_DlnneEnabled(const PD_AnalysisConfig* config) {
+  PADDLE_ENFORCE_NOT_NULL(
+      config,
+      paddle::platform::errors::InvalidArgument(
+          "The pointer of analysis configuration shouldn't be nullptr"));
+  return config->config.dlnne_enabled();
+}
+
 void PD_SwitchIrDebug(PD_AnalysisConfig* config, bool x) {
   PADDLE_ENFORCE_NOT_NULL(
       config,
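The C API mirror is equally thin. A sketch, assuming the usual capi lifecycle functions (PD_NewAnalysisConfig/PD_DeleteAnalysisConfig) and umbrella header name:

```cpp
#include "paddle/fluid/inference/capi/paddle_c_api.h"  // assumed header name

int main() {
  PD_AnalysisConfig *config = PD_NewAnalysisConfig();
  PD_EnableDlnne(config, /*min_subgraph_size=*/3);
  if (!PD_DlnneEnabled(config)) return 1;
  PD_DeleteAnalysisConfig(config);
  return 0;
}
```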
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 7b17899d3da2383d7aef0593cd7f40c529afbdaf..60fa8e319d954720aa091b45064f9f798354be2f 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -42,6 +42,10 @@ if (WITH_GPU AND TENSORRT_FOUND)
   add_subdirectory(tensorrt)
 endif()
 
+if (WITH_DLNNE)
+  add_subdirectory(dlnne)
+endif()
+
 if (WITH_LITE)
   add_subdirectory(lite)
 endif()
diff --git a/paddle/fluid/operators/dlnne/CMakeLists.txt b/paddle/fluid/operators/dlnne/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4fe9cf214eaa700326ea84dd6d4b3a6001c23365
--- /dev/null
+++ b/paddle/fluid/operators/dlnne/CMakeLists.txt
@@ -0,0 +1,54 @@
+# compile flags
+set(DLNNE_FLAGS
+  -Wno-error=non-virtual-dtor
+  -Wno-error=unused-variable
+  -Wno-error=attributes
+  ${fsanitize}
+)
+foreach(flag ${DLNNE_FLAGS})
+  safe_set_cflag(CMAKE_C_FLAGS ${flag})
+  safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag})
+endforeach()
+
+# add nne
+find_path(DLNNE_INCLUDE_DIR dlnne.h
+  PATHS
+  $ENV{SOFTWARE_SOURCE_DIR} $ENV{SOFTWARE_SOURCE_DIR}/driver/nne/include
+  NO_DEFAULT_PATH
+)
+
+find_library(DLNNE_LIB libdlnne.so
+  PATHS
+  $ENV{SOFTWARE_BUILD_DIR} $ENV{SOFTWARE_BUILD_DIR}/driver/nne
+  NO_DEFAULT_PATH
+)
+
+find_path(CUDA_INCLUDE_DIR cuda.h
+  $ENV{SOFTWARE_BUILD_DIR}/llvm-project-10/cuda/include
+)
+
+find_library(CURT_LIB libcurt.so
+  PATHS
+  $ENV{SOFTWARE_BUILD_DIR} $ENV{SOFTWARE_BUILD_DIR}/llvm-project-10/cuda/lib
+  NO_DEFAULT_PATH
+)
+
+message("DLNNE_INCLUDE_DIR: ${DLNNE_INCLUDE_DIR}")
+message("DLNNE_LIB: ${DLNNE_LIB}")
+message("CUDA_INCLUDE_DIR: ${CUDA_INCLUDE_DIR}")
+message("CURT_LIB: ${CURT_LIB}")
+
+include_directories("${DLNNE_INCLUDE_DIR}")
+include_directories("${CUDA_INCLUDE_DIR}")
+
+op_library(dlnne_engine_op DEPS ${GLOB_OPERATOR_DEPS} framework_proto boost device_context op_registry scope)
+
+#message("PYBIND_FILE: ${pybind_file}")
+#file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(dlnne_engine);\n")
+
+target_link_libraries(dlnne_engine_op ${DLNNE_LIB} ${CURT_LIB})
+
+cc_test(test_dlnne_engine_op SRCS dlnne_engine_op_test.cc DEPS dlnne_engine_op analysis)
diff --git a/paddle/fluid/operators/dlnne/dlnne_engine_op.cc b/paddle/fluid/operators/dlnne/dlnne_engine_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4654e6a9f978a2885c369c97515aa1c6b1085245
--- /dev/null
+++ b/paddle/fluid/operators/dlnne/dlnne_engine_op.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/dlnne/dlnne_engine_op.h"
+
+namespace paddle {
+namespace inference {
+
+void CopyTensorDeviceToCpu(void* dst_ptr, void* src_ptr, int total_bytes) {
+  cudaDeviceSynchronize();
+  cudaMemcpy(dst_ptr, src_ptr, total_bytes, cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+}
+
+void CopyTensorCpuToDevice(void* dst_ptr, void* src_ptr, int total_bytes) {
+  cudaDeviceSynchronize();
+  cudaMemcpy(dst_ptr, src_ptr, total_bytes, cudaMemcpyHostToDevice);
+  cudaDeviceSynchronize();
+}
+
+}  // namespace inference
+
+namespace operators {
+
+class DlnneEngineOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Xs", "A list of inputs.").AsDuplicable();
+    AddOutput("Ys", "A list of outputs").AsDuplicable();
+    AddAttr<std::string>("subgraph", "the subgraph.");
+    AddAttr<std::string>(
+        "engine_key",
+        "The engine_key here is used to distinguish different DLNNE Engines");
+    AddAttr<framework::BlockDesc *>("sub_block", "the dlnne block");
+    AddComment("Dlnne engine operator.");
+  }
+};
+
+class DlnneEngineInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext* ctx) const override {}
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(dlnne_engine, ops::DlnneEngineOp, ops::DlnneEngineOpMaker);
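Both copy helpers are deliberately fully synchronous: the cudaMemcpy is bracketed by cudaDeviceSynchronize on each side, trading throughput for simple correctness. A minimal round-trip sketch:

```cpp
#include <cuda_runtime.h>

#include "paddle/fluid/operators/dlnne/dlnne_engine_op.h"

int main() {
  float host[4] = {0.f, 1.f, 2.f, 3.f};
  void *dev = nullptr;
  cudaMalloc(&dev, sizeof(host));
  // Host -> device and back; both calls block until the copy has landed.
  paddle::inference::CopyTensorCpuToDevice(dev, host, sizeof(host));
  paddle::inference::CopyTensorDeviceToCpu(host, dev, sizeof(host));
  cudaFree(dev);
  return 0;
}
```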
diff --git a/paddle/fluid/operators/dlnne/dlnne_engine_op.h b/paddle/fluid/operators/dlnne/dlnne_engine_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..d426876c18fa5e7033c0787e8cec82758c3517e8
--- /dev/null
+++ b/paddle/fluid/operators/dlnne/dlnne_engine_op.h
@@ -0,0 +1,351 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <cuda.h>          // NOTLINT
+#include <cuda_runtime.h>  // NOTLINT
+#include <dlnne.h>         // NOTLINT
+
+#include <unistd.h>
+
+#include <ctime>
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+
+namespace dl {
+namespace nne {
+class Builder;
+class Engine;
+class Network;
+class Parser;
+class ExecutionContext;
+}  // namespace nne
+}  // namespace dl
+
+namespace paddle {
+namespace inference {
+
+class NneDeleter {
+ public:
+  NneDeleter() {}
+
+  template <typename T>
+  inline void operator()(T *ptr) {
+    if (ptr != nullptr) {
+      ptr->Destroy();
+    }
+  }
+};
+
+void CopyTensorDeviceToCpu(void *dst_ptr, void *src_ptr, int total_bytes);
+
+void CopyTensorCpuToDevice(void *dst_ptr, void *src_ptr, int total_bytes);
+
+template <typename T>
+struct Singleton;
+}  // namespace inference
+}  // namespace paddle
+
+namespace paddle {
+
+namespace operators {
+
+class DlnneEngineOp : public framework::OperatorBase {
+ private:
+  std::vector<std::string> input_names_;
+  std::unordered_set<std::string> param_names_;
+  std::string engine_key_;
+  int num_inputs;
+  int num_outputs;
+  std::vector<std::string> output_names;
+  std::vector<std::string> input_names;
+
+  dl::nne::Builder *builder;
+  dl::nne::Parser *parser;
+  dl::nne::Network *network;
+  dl::nne::ExecutionContext *context;
+  dl::nne::Engine *engine;
+
+  unsigned int engine_input_size;
+  std::vector<int> InputIndexToBindIndex_;
+
+ public:
+  DlnneEngineOp(const std::string &type,
+                const framework::VariableNameMap &inputs,
+                const framework::VariableNameMap &outputs,
+                const framework::AttributeMap &attrs)
+      : framework::OperatorBase(type, inputs, outputs, attrs) {
+    input_names_ = Inputs("Xs");
+    engine_key_ = Attr<std::string>("engine_key");
+    auto params = Attr<std::vector<std::string>>("parameters");
+    for (const auto &param : params) {
+      param_names_.insert(param);
+    }
+
+    num_inputs = 0;
+    for (const auto &x : Inputs("Xs")) {
+      if (param_names_.count(x)) continue;
+      num_inputs += 1;
+      input_names.push_back(x);
+    }
+
+    num_outputs = Outputs("Ys").size();
+    for (const auto &y : Outputs("Ys")) {
+      VLOG(4) << "y: " << y;
+      output_names.push_back(y);
+    }
+
+    // onnx path
+    std::stringstream filename;
+    std::string current_path = ".";
+    char *buffer;
+    if ((buffer = getcwd(NULL, 0)) != NULL) {
+      current_path = buffer;
+    } else {
+      current_path = ".";
+    }
+    filename << current_path << "/dump/" << engine_key_ << "/" << engine_key_
+             << ".onnx";
+
+    builder = dl::nne::CreateInferBuilder();
+    PADDLE_ENFORCE_NE(builder, nullptr, platform::errors::Unavailable(
+                                            "nne create builder failed"));
+    parser = dl::nne::CreateParser();
+    PADDLE_ENFORCE_NE(parser, nullptr, platform::errors::Unavailable(
+                                           "nne create parser failed"));
+
+    network = builder->CreateNetwork();
+
+    LOG(INFO) << "set output for dlnne";
+    for (std::string &output_op_name : output_names)
+      parser->RegisterOutput(output_op_name.c_str());
+
+    LOG(INFO) << "parse onnx for dlnne";
+    parser->Parse(filename.str().c_str(), *network);
+
+    LOG(INFO) << "build network";
+    engine = builder->BuildEngine(*network);
+
+    // total size = input_size + output_size
+    engine_input_size = num_inputs + num_outputs;
+    for (std::string &input_name : input_names) {
+      int BindIndex = engine->GetBindingIndex(input_name.c_str());
+      InputIndexToBindIndex_.push_back(BindIndex);
+    }
+
+    for (std::string &output_name : output_names) {
+      int BindIndex = engine->GetBindingIndex(output_name.c_str());
+      InputIndexToBindIndex_.push_back(BindIndex);
+    }
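+
+    // Note: InputIndexToBindIndex_ holds the engine binding indices for all
+    // inputs first, then all outputs, so slot i < num_inputs refers to an
+    // input binding and slot i >= num_inputs to an output binding;
+    // RunDlnneOnCreateEngine below relies on this ordering.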
+
+    // context
+    context = engine->CreateExecutionContext();
+  }
+
+  ~DlnneEngineOp() {
+    network->Destroy();
+    context->Destroy();
+    engine->Destroy();
+    parser->Destroy();
+    builder->Destroy();
+  }
+
+ protected:
+  void RunDlnneOnCreateEngine(const framework::Scope &scope,
+                              const platform::Place &dev_place) const {
+    PADDLE_ENFORCE_EQ(
+        input_names_.empty(), false,
+        platform::errors::PreconditionNotMet(
+            "Dlnne engine needs at least one input, but no input is found. "
+            "Please check if you set the input correctly."));
+
+    std::vector<void *> input_buffers(num_inputs);
+    std::vector<void *> cpu_input_buffers(num_inputs);
+    std::vector<std::vector<int64_t>> input_shapes(num_inputs);
+    std::vector<int32_t> input_data_types(num_inputs);
+    std::vector<int64_t> input_bytes(num_inputs);
+
+    int index = 0;
+    for (const auto &x : Inputs("Xs")) {
+      if (param_names_.count(x)) continue;
+      // convert input and copy to Dlnne engine's buffer
+      auto &t =
+          inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
+
+      const int bind_index = index;
+      index++;
+      int64_t data_bytes;
+      int32_t dtype;
+      auto type = t.type();
+      data_bytes = 1;
+      void *buffer = nullptr;
+      if (type == framework::proto::VarType::FP32) {
+        buffer = static_cast<void *>(t.data<float>());
+        data_bytes = 4;
+        dtype = 0;
+      } else if (type == framework::proto::VarType::INT64) {
+        buffer = static_cast<void *>(t.data<int64_t>());
+        data_bytes = 8;
+        dtype = 1;
+      } else if (type == framework::proto::VarType::INT32) {
+        buffer = static_cast<void *>(t.data<int32_t>());
+        data_bytes = 4;
+        dtype = 2;
+      } else {
+        PADDLE_THROW(platform::errors::Fatal(
+            "The DLNNE Engine OP only support float/int32_t/int64_t input."));
+      }
+      input_buffers[bind_index] = buffer;
+
+      auto t_shape = framework::vectorize(t.dims());
+      std::vector<int64_t> runtime_input_shape(t_shape.begin(), t_shape.end());
+      for (auto &size : t_shape) {
+        data_bytes = data_bytes * size;
+      }
+
+      VLOG(4) << "buffers_size:" << data_bytes;
+      cpu_input_buffers[bind_index] =
+          input_buffers[bind_index];  // malloc(data_bytes);
+      input_shapes[bind_index] = runtime_input_shape;
+      input_data_types[bind_index] = dtype;
+      input_bytes[bind_index] = data_bytes;
+    }
+
+    // output shape
+    std::vector<std::vector<int64_t>> out_shapes;
+    std::vector<int64_t> output_bytes;
+    for (int i = 0; i < num_outputs; i++) {
+      int index = engine->GetBindingIndex(output_names[i].c_str());
+      dl::nne::Dims out_dim = engine->GetBindingDimensions(index);
+      std::vector<int64_t> shape(out_dim.nbDims);
+      for (int dim = 0; dim < out_dim.nbDims; dim++) {
+        shape[dim] = (out_dim.d[dim]);
+      }
+
+      out_shapes.push_back(shape);
+      int64_t data_bytes;
+
+      // float32
+      data_bytes = 4;
+      for (auto &size : shape) {
+        data_bytes = data_bytes * size;
+      }
+      VLOG(4) << "data_bytes: " << data_bytes;
+      output_bytes.push_back(data_bytes);
+    }
+
+    int bind_index = 0;
+    std::vector<void *> cpu_output_buffers(num_outputs);
+    std::vector<void *> output_buffers(num_outputs);
+    std::vector<int32_t> output_dtypes(num_outputs);
+
+    for (const auto &y : Outputs("Ys")) {
+      auto *fluid_v = scope.FindVar(y);
+      PADDLE_ENFORCE_NOT_NULL(
+          fluid_v,
+          platform::errors::NotFound(
+              "Output variable %s is not found in DLNNE subgraph.", y));
+
+      auto *fluid_t = fluid_v->GetMutable<framework::LoDTensor>();
+
+      VLOG(4) << "out_shapes[bind_index] dim:" << out_shapes[bind_index].size();
+      fluid_t->Resize(framework::make_ddim(out_shapes[bind_index]));
+
+      int32_t dtype;
+      output_buffers[bind_index] = fluid_t->mutable_data<float>(
+          BOOST_GET_CONST(platform::CPUPlace, dev_place));
+      dtype = 0;
+      cpu_output_buffers[bind_index] =
+          output_buffers[bind_index];  // malloc(data_bytes);
+      output_dtypes[bind_index] = dtype;
+      bind_index++;
+    }
+
+    std::vector<void *> engine_input_ptr(engine_input_size);
+
+    // set input_ptr
+    for (unsigned int i = 0; i < engine_input_size; i++) {
+      if (InputIndexToBindIndex_[i] < 0) continue;
+
+      if (engine->BindingIsInput(InputIndexToBindIndex_[i])) {
+        // copy cpu buffer to gpu buffer
+        int64_t total_bytes;
+        total_bytes = input_bytes[i];
+        VLOG(4) << "input_bytes: " << total_bytes;
+
+        void *gpu_ptr;
+        cudaMalloc(&gpu_ptr, total_bytes);
+        engine_input_ptr[InputIndexToBindIndex_[i]] = gpu_ptr;
+
+        paddle::inference::CopyTensorCpuToDevice(
+            gpu_ptr, reinterpret_cast<void *>(cpu_input_buffers[i]),
+            total_bytes);
+
+      } else {
+        int64_t total_size;
+        total_size = output_bytes[i - input_names.size()];
+        VLOG(4) << "output_bytes: " << total_size;
+        void *gpu_ptr;
+        cudaMalloc(&gpu_ptr, total_size);
+        engine_input_ptr[InputIndexToBindIndex_[i]] = gpu_ptr;
+      }
+    }
+
+    clock_t startTime, endTime;
+    startTime = clock();
+    context->Execute(1, engine_input_ptr.data());
+    endTime = clock();
+    double during_ms =
+        static_cast<double>(endTime - startTime) / CLOCKS_PER_SEC * 1000;
+    LOG(INFO) << "dlNNE execute time: " << during_ms << " ms";
+
+    bind_index = 0;
+    for (unsigned int i = 0; i < engine_input_size; i++) {
+      if (InputIndexToBindIndex_[i] < 0) continue;
+
+      if (i >= input_names.size()) {
+        void *cpu_ptr = cpu_output_buffers[i - input_names.size()];
+        int64_t size;
+        size = output_bytes[i - input_names.size()];
+        paddle::inference::CopyTensorDeviceToCpu(
+            cpu_ptr, engine_input_ptr[InputIndexToBindIndex_[i]], size);
+        // dtype: float32
+        int32_t dtypes;
+        dtypes = 0;
+
+        cpu_output_buffers[bind_index] = cpu_ptr;
+        output_dtypes[bind_index] = dtypes;
+        bind_index++;
+      }
+      cudaFree(engine_input_ptr[InputIndexToBindIndex_[i]]);
+    }
+  }
+
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &dev_place) const override {
+    RunDlnneOnCreateEngine(scope, dev_place);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
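The engine_key is the contract between the pass and the op: the pass dumps the subgraph under ./dump/<engine_key>/ and triggers conversion, and the op constructor then parses ./dump/<engine_key>/<engine_key>.onnx. An illustrative helper capturing that path convention (not part of the patch itself):

```cpp
#include <string>

// Where DlnneEngineOp expects the converted model for a given engine key.
std::string OnnxPathFor(const std::string &engine_key) {
  return "./dump/" + engine_key + "/" + engine_key + ".onnx";
}
```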
diff --git a/paddle/fluid/operators/dlnne/dlnne_engine_op_test.cc b/paddle/fluid/operators/dlnne/dlnne_engine_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..caf1a80fcc737f1883ecb7b94e43e383e8b830d4
--- /dev/null
+++ b/paddle/fluid/operators/dlnne/dlnne_engine_op_test.cc
@@ -0,0 +1,237 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/dlnne/dlnne_engine_op.h"
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h"
+
+USE_NO_KERNEL_OP(dlnne_engine);
+
+namespace paddle {
+namespace operators {
+
+namespace {
+void CreateCUDATensor(framework::Scope* scope, const std::string& name,
+                      const std::vector<int64_t>& shape) {
+  auto* var = scope->Var(name);
+  auto* tensor = var->GetMutable<framework::LoDTensor>();
+  auto dims = framework::make_ddim(shape);
+  tensor->Resize(dims);
+  platform::CUDAPlace place;
+  platform::CUDADeviceContext ctx(place);
+  inference::tensorrt::RandomizeTensor(tensor, place, ctx);
+}
+
+void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
+                          const std::string& name,
+                          const std::vector<int64_t>& shape) {
+  using framework::proto::VarType;
+  auto* var = block->add_vars();
+  framework::VarDesc desc(name);
+  desc.SetType(VarType::LOD_TENSOR);
+  desc.SetDataType(VarType::FP32);
+  desc.SetShape(shape);
+  *var = *desc.Proto();
+}
+
+}  // namespace
+
+using inference::analysis::SetAttr;
+
+TEST(DlnneEngineOp, manual) {
+  framework::ProgramDesc program;
+  auto* block_ = program.Proto()->add_blocks();
+  block_->set_idx(0);
+  block_->set_parent_idx(-1);
+
+  LOG(INFO) << "create block desc";
+  framework::BlockDesc block_desc(&program, block_);
+  LOG(INFO) << "create fc op";
+  auto* fc0 = block_desc.AppendOp();
+  fc0->SetType("fc");
+  fc0->SetInput("X", std::vector<std::string>({"x"}));     // 4 x 1 x 1
+  fc0->SetInput("Y", std::vector<std::string>({"y"}));     // 4 x 6
+  fc0->SetOutput("Out", std::vector<std::string>({"z"}));  // 6 x 1 x 1
+
+  LOG(INFO) << "create fc op";
+  auto* fc1 = block_desc.AppendOp();
+  fc1->SetType("fc");
+  fc1->SetInput("X", std::vector<std::string>({"z"}));
+  fc1->SetInput("Y", std::vector<std::string>({"y0"}));     // 6 x 8
+  fc1->SetOutput("Out", std::vector<std::string>({"z0"}));  // 8 x 1 x 1
+
+  // Set inputs' variable shape in BlockDesc
+  // the batch size is 2, so the dims of 'x' is {2, 4, 1, 1}
+  AddTensorToBlockDesc(block_, "x", std::vector<int64_t>({2, 4, 1, 1}));
+  AddTensorToBlockDesc(block_, "y", std::vector<int64_t>({4, 6}));
+  AddTensorToBlockDesc(block_, "y0", std::vector<int64_t>({6, 8}));
+  AddTensorToBlockDesc(block_, "z", std::vector<int64_t>({2, 6}));
+
+  // It is weird, but the ops need to be copied into the proto manually.
+  *block_->add_ops() = *fc0->Proto();
+  *block_->add_ops() = *fc1->Proto();
+
+  ASSERT_EQ(block_->ops_size(), 2);
+
+  LOG(INFO) << "create dlnne desc";
+  framework::OpDesc engine_op_desc(nullptr);
+  engine_op_desc.SetType("dlnne_engine");
+  engine_op_desc.SetInput("Xs", std::vector<std::string>({"x"}));
+  engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
+
+  engine_op_desc.SetBlockAttr("sub_block", &block_desc);
+  engine_op_desc.SetAttr("max_batch_size", static_cast<int>(2));
+  engine_op_desc.SetAttr("workspace_size", static_cast<int64_t>(1 << 20));
+  engine_op_desc.SetAttr("parameters", std::vector<std::string>({}));
+  engine_op_desc.SetAttr("engine_key", std::string("a_engine"));
+  engine_op_desc.SetAttr("calibration_engine_key",
+                         std::string("a_calib_engine"));
+  engine_op_desc.SetAttr("predictor_id", 1);
+  engine_op_desc.SetAttr("calibration_data", std::string(""));
+  engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
+  engine_op_desc.SetAttr("enable_fp16", static_cast<bool>(false));
+  engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
+  engine_op_desc.SetAttr("output_name_mapping",
+                         std::vector<std::string>({"z0"}));
+  engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
+  engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
+  engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
+  int device_id = 0;
+  engine_op_desc.SetAttr("gpu_id", device_id);
+
+  LOG(INFO) << "create engine op";
+  auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
+  LOG(INFO) << "engine_op " << engine_op.get();
+
+  framework::Scope scope;
+  platform::CUDAPlace place;
+  platform::CUDADeviceContext ctx(place);
+  // Prepare variables.
+  CreateCUDATensor(&scope, "x", std::vector<int64_t>({2, 4}));
+  CreateCUDATensor(&scope, "y", std::vector<int64_t>({4, 6}));
+  CreateCUDATensor(&scope, "z", std::vector<int64_t>({2, 6}));
+
+  CreateCUDATensor(&scope, "y0", std::vector<int64_t>({6, 8}));
+  CreateCUDATensor(&scope, "z0", std::vector<int64_t>({2, 8}));
+
+  // Execute them.
+  LOG(INFO) << "engine_op run";
+  engine_op->Run(scope, place);
+}
+
+void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
+  framework::ProgramDesc program;
+  framework::Scope scope;
+  platform::CUDAPlace place;
+  platform::CUDADeviceContext ctx(place);
+
+  auto* block_ = program.Proto()->add_blocks();
+  block_->set_idx(0);
+  block_->set_parent_idx(-1);
+
+  using shape_t = std::vector<int64_t>;
+
+  LOG(INFO) << "create block desc";
+  framework::BlockDesc block_desc(&program, block_);
+
+  auto AddFCLayer = [&](const std::string& x_name, const std::string& y_name,
+                        const std::string& z_name, bool x_created,
+                        const shape_t& x_shape, const shape_t& y_shape,
+                        const shape_t& z_shape) {
+    LOG(INFO) << "create fc op";
+    auto* fc = block_desc.AppendOp();
+    fc->SetType("mul");
+    fc->SetInput("X", std::vector<std::string>({x_name}));
+    fc->SetInput("Y", std::vector<std::string>({y_name}));
+    fc->SetOutput("Out", std::vector<std::string>({z_name}));
+
+    // Set inputs' variable shape in BlockDesc
+    if (!x_created) {
+      AddTensorToBlockDesc(block_, x_name,
+                           std::vector<int64_t>({batch_size, input_dim, 1, 1}));
+    }
+    AddTensorToBlockDesc(block_, y_name,
+                         std::vector<int64_t>({input_dim, output_dim}));
+    AddTensorToBlockDesc(block_, z_name,
+                         std::vector<int64_t>({batch_size, output_dim}));
+
+    // Prepare variables.
+    if (!x_created) {
+      CreateCUDATensor(&scope, x_name, std::vector<int64_t>(x_shape));
+    }
+    CreateCUDATensor(&scope, y_name, std::vector<int64_t>(y_shape));
+    CreateCUDATensor(&scope, z_name, std::vector<int64_t>(z_shape));
+
+    // It is weird, but the op needs to be copied into the proto manually.
+    *block_->add_ops() = *fc->Proto();
+  };
+
+  // Test with 4 layer FC
+  AddFCLayer("x0", "y0", "z0", false, {batch_size, input_dim},
+             {input_dim, output_dim}, {batch_size, output_dim});
+  AddFCLayer("z0", "y1", "z1", true, {}, {output_dim, output_dim},
+             {batch_size, output_dim});
+  AddFCLayer("z1", "y2", "z2", true, {}, {output_dim, output_dim},
+             {batch_size, output_dim});
+  AddFCLayer("z2", "y3", "z3", true, {}, {output_dim, output_dim},
+             {batch_size, output_dim});
+
+  LOG(INFO) << "create dlnne desc";
+  framework::OpDesc engine_op_desc(nullptr);
+  engine_op_desc.SetType("dlnne_engine");
+  engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
+  engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));
+
+  engine_op_desc.SetBlockAttr("sub_block", &block_desc);
+  engine_op_desc.SetAttr("max_batch_size", static_cast<int>(batch_size));
+  engine_op_desc.SetAttr("workspace_size", static_cast<int64_t>(1 << 20));
+  engine_op_desc.SetAttr("parameters",
+                         std::vector<std::string>({"y0", "y1", "y2", "y3"}));
+  engine_op_desc.SetAttr("engine_key", std::string("b_engine"));
+  engine_op_desc.SetAttr("calibration_engine_key",
+                         std::string("b_calib_engine"));
+  engine_op_desc.SetAttr("predictor_id", 1);
+  engine_op_desc.SetAttr("calibration_data", std::string(""));
+  engine_op_desc.SetAttr("enable_int8", static_cast<bool>(false));
+  engine_op_desc.SetAttr("enable_fp16", static_cast<bool>(false));
+  engine_op_desc.SetAttr("use_calib_mode", static_cast<bool>(false));
+  engine_op_desc.SetAttr("output_name_mapping",
+                         std::vector<std::string>({"z3"}));
+  engine_op_desc.SetAttr("origin_output_dims", std::vector<int>({2}));
+  engine_op_desc.SetAttr("subgraph", std::string(block_->SerializeAsString()));
+  engine_op_desc.SetAttr("engine_serialized_data", std::string(""));
+  int device_id = 0;
+  engine_op_desc.SetAttr("gpu_id", device_id);
+
+  auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
+
+  // Execute them.
+  engine_op->Run(scope, place);
+}
+
+// Test with a larger FC layer.
+TEST(DlnneEngineOp, fc) { Execute(40, 28, 28); }
+
+}  // namespace operators
+}  // namespace paddle
+
+USE_TRT_CONVERTER(fc)
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index dd9cb65142a3de71fec247185328ebed8c98a03a..606af27f6baf2a06a3670e6d87065c68513a7241 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -512,6 +512,8 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("dla_core") = 0)
       .def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
       .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
+      .def("enable_dlnne", &AnalysisConfig::EnableDlnne,
+           py::arg("min_subgraph_size") = 3)
       .def("enable_lite_engine", &AnalysisConfig::EnableLiteEngine,
            py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
            py::arg("zero_copy") = false,
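On the Python side, this binding is presumably used as config.enable_dlnne(min_subgraph_size=3) on an AnalysisConfig instance, mirroring the C++ default of 3; the Python-side registration of the convert_graph callback itself is not part of this diff.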