diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 997022abde3f9c500098573d47dc08c1e7e107e6..17a2c5852cbf87268a58439a8cb896aa3ec362fa 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -250,6 +250,22 @@ struct Argument { DECL_ARGUMENT_FIELD(use_dlnne, UseDlnne, bool); DECL_ARGUMENT_FIELD(dlnne_min_subgraph_size, DlnneMinSubgraphSize, int); DECL_ARGUMENT_FIELD(dlnne_max_batch_size, DlnneMaxBatchSize, int); + DECL_ARGUMENT_FIELD(dlnne_use_static_batch, DlnneUseStaticBatch, bool); + DECL_ARGUMENT_FIELD(dlnne_weight_share_mode, + DlnneWeightShareMode, + std::string); + DECL_ARGUMENT_FIELD(dlnne_disable_nodes_by_outputs, + DlnneDisableNodesByOutputs, + std::unordered_set); + DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool); + DECL_ARGUMENT_FIELD(dlnne_precision_mode, + DlnnePrecisionMode, + AnalysisConfig::Precision); + + using dlnne_input_shape_type = std::map>; + DECL_ARGUMENT_FIELD(dlnne_input_shape_dict, + DlnneInputShapeDict, + dlnne_input_shape_type); DECL_ARGUMENT_FIELD(dlnne_workspace_size, DlnneWorkspaceSize, int); DECL_ARGUMENT_FIELD(lite_passes_filter, diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 723a787722143dc8d497c9a143469bfa7b53edd3..f86a22e3db9e1d4960cc99d74d4b7c28b493d28d 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -209,8 +209,23 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("disable_trt_plugin_fp16", new bool(argument->disable_trt_plugin_fp16())); } else if (pass_name == "dlnne_subgraph_pass") { + auto precision_mode = argument->dlnne_precision_mode(); pass->Set("min_subgraph_size", new int(argument->dlnne_min_subgraph_size())); + pass->Set("max_batch_size", new int(argument->dlnne_max_batch_size())); + pass->Set("use_static_batch", + new bool(argument->dlnne_use_static_batch())); + pass->Set("weight_share_mode", + new std::string(argument->dlnne_weight_share_mode())); + pass->Set("disable_nodes_by_outputs", + new std::unordered_set( + argument->dlnne_disable_nodes_by_outputs())); + pass->Set("use_calib_mode", new bool(argument->dlnne_use_calib_mode())); + pass->Set("precision_mode", + new AnalysisConfig::Precision(precision_mode)); + pass->Set("input_shape_dict", + new std::map>( + argument->dlnne_input_shape_dict())); pass->Set("program", new framework::ProgramDesc *(&argument->main_program())); } diff --git a/paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h b/paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h deleted file mode 100644 index ae977c1403a8793b0611496702515f1df952d5a1..0000000000000000000000000000000000000000 --- a/paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#pragma once - -namespace paddle { -namespace inference { - -int RegisterPyFunc(const std::string& name, void* pfn); -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc index 93fbc1d882be99ca3f2f510c52a1b6aaabd086dd..3056eff9ae15c738d0749e2c99ecc228e74b97e3 100644 --- a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc @@ -11,87 +11,339 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h" - #include #include #include #include #include +#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/inference/analysis/helper.h" -#include "paddle/fluid/inference/analysis/ir_passes/dlnne_reg_py.h" +#include "paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h" #include "paddle/fluid/string/pretty_log.h" namespace paddle { namespace inference { +namespace analysis { -int (*PyConvertGraph)(const char *graph_name); +using framework::ir::Node; -int RegisterPyFunc(const std::string &name, void *pfn) { - if (name.compare("convert_graph") == 0) { - PyConvertGraph = reinterpret_cast(pfn); +void analysis::DlnneSubgraphPass::InferShapeForDlnneMainGraph() const { + // copy from paddle2onnx + static std::unordered_set OP_WITHOUT_KERNEL_SET = { + "feed", + "fetch", + "recurrent", + "go", + "rnn_memory_helper_grad", + "conditional_block", + "while", + "send", + "recv", + "listen_and_serv", + "fl_listen_and_serv", + "ncclInit", + "select", + "checkpoint_notify", + "gen_bkcl_id", + "c_gen_bkcl_id", + "gen_nccl_id", + "c_gen_nccl_id", + "c_comm_init", + "c_sync_calc_stream", + "c_sync_comm_stream", + "queue_generator", + "dequeue", + "enqueue", + "heter_listen_and_serv", + "c_wait_comm", + "c_wait_compute"}; + + std::string bilinear_interp_v2_type = "bilinear_interp_v2"; + auto input_dict = + Get>>("input_shape_dict"); + + framework::ProgramDesc *global_program = + Get("program"); + auto block = global_program->MutableBlock(framework::kRootBlockIndex); + for (auto kv : input_dict) { + auto var = block->FindVar(kv.first); + if (var != nullptr) { + var->SetShape(kv.second); + } else { + VLOG(4) << "input_name:" << kv.first << " not find in all input vars"; + } } - return 0; -} -int ConvertGraph(std::string graph_name) { - LOG(INFO) << "starting doing convert_graph"; + std::vector all_ops = block->AllOps(); + + for (size_t i = 0; i < block->OpSize(); i++) { + // the output_shape of bilinear_interp_v2 cannot be inferd by input shape, + // it also need the value of input tensor, so when call OpDesc->InferShape, + // the output_shape of bilinear_interp_v2 is still dynamic, here we try to + // infer the output_shape of bilinear_interp_v2 infer shape for + // bilinear_interp_v2 + if (block->Op(i)->Type() == bilinear_interp_v2_type) { + framework::VariableNameMap input_name_map = block->Op(i)->Inputs(); + std::vector input_name_vec = input_name_map["OutSize"]; + PADDLE_ENFORCE_EQ( + input_name_vec.size(), + 1, + platform::errors::PreconditionNotMet( + "The 'bilinear_interp_v2 op' input 'OutSize' size must be 1 ")); + + // find shape->slice->bilinear_interp_v2 pattern + int start_id = 0; + int end_id = 0; + std::vector slice_input_name_vec; + for (auto *i_op : all_ops) { + if (i_op->HasOutput("Out")) { + auto it = find(i_op->Output("Out").begin(), + i_op->Output("Out").end(), + input_name_vec[0]); + if (it != i_op->Output("Out").end()) { + slice_input_name_vec = i_op->Input("Input"); + PADDLE_ENFORCE_EQ( + slice_input_name_vec.size(), + 1, + platform::errors::PreconditionNotMet( + "The 'slice op' input 'Input' size must be 1 ")); + + auto start_vec = i_op->GetAttrIfExists>("starts"); + start_id = start_vec[0]; + auto end_vec = i_op->GetAttrIfExists>("ends"); + end_id = end_vec[0]; + break; + } + } + } - PyConvertGraph(graph_name.c_str()); + std::vector shape_input_name_vec; + for (auto *i_op : all_ops) { + if (i_op->HasOutput("Out")) { + auto it = find(i_op->Output("Out").begin(), + i_op->Output("Out").end(), + slice_input_name_vec[0]); + if (it != i_op->Output("Out").end()) { + shape_input_name_vec = i_op->Input("Input"); + PADDLE_ENFORCE_EQ( + slice_input_name_vec.size(), + 1, + platform::errors::PreconditionNotMet( + "The 'shape op' input 'Input' size must be 1 ")); + break; + } + } + } + auto target_var = block->FindVarRecursive(shape_input_name_vec[0]); + std::vector target_shape = target_var->GetShape(); + size_t target_shape_len = target_shape.size(); + if (start_id < 0) { + start_id = target_shape_len + start_id; + } else if (start_id > static_cast(target_shape_len)) { + start_id = target_shape_len; + } - return 0; -} + if (end_id < 0) { + end_id = target_shape_len + end_id; + } else if (end_id > static_cast(target_shape_len)) { + end_id = target_shape_len; + } -namespace analysis { + if (start_id < end_id) { + std::vector OutSize_dims(target_shape.begin() + start_id, + target_shape.begin() + end_id); + + framework::VariableNameMap output_name_map = block->Op(i)->Outputs(); + std::vector output_name_vec = output_name_map["Out"]; + auto out_var = block->FindVarRecursive(output_name_vec[0]); + PADDLE_ENFORCE_NOT_NULL( + out_var, + platform::errors::NotFound( + "bilinear_interp_v2 op's output %s is not found in the block.", + output_name_vec[0])); + std::vector ori_shape = out_var->GetShape(); + std::string data_layout = + block->Op(i)->GetAttrIfExists("data_layout"); + size_t start_dim = 0; + size_t end_dim = 0; + + if (data_layout == "NCHW") { + start_dim = 2; + end_dim = ori_shape.size(); + } else { + start_dim = 1; + end_dim = ori_shape.size() - 1; + } + for (size_t i_dim = start_dim; i_dim < end_dim; i_dim++) { + ori_shape[i_dim] = OutSize_dims[i_dim - start_dim]; + } -using framework::ir::Node; + VLOG(4) << "Set bilinear_interp_v2 shape: " << ori_shape[2] << ", " + << ori_shape[3]; + out_var->SetShape(ori_shape); + } + + } else { + if (OP_WITHOUT_KERNEL_SET.find(block->Op(i)->Type()) == + OP_WITHOUT_KERNEL_SET.end()) + block->Op(i)->InferShape(*block); + } + } +} + +bool analysis::DlnneSubgraphPass::IsDynamicOp(std::string var_name, + bool use_static_batch) const { + framework::ProgramDesc *global_program = + Get("program"); + auto block = global_program->MutableBlock(framework::kRootBlockIndex); + auto var = block->FindVar(var_name); + + if (var != nullptr) { + std::vector var_shape = var->GetShape(); + size_t start_idx = use_static_batch ? 1 : 0; + for (; start_idx < var_shape.size(); start_idx++) { + if (var_shape[start_idx] < 1) { + return false; + } + } + } + return true; +} void analysis::DlnneSubgraphPass::ApplyImpl(framework::ir::Graph *graph) const { + framework::ir::FusePassBase::Init("dlnne_subgraph_pass", graph); + + InferShapeForDlnneMainGraph(); + static std::unordered_set teller_set{ + "nearest_interp_v2", "mul", "matmul", + "matmul_v2", + "flatten_contiguous_range", "conv2d", "pool2d", "relu", "softmax", "sigmoid", + "softplus", "hard_swish", + "hard_sigmoid", "depthwise_conv2d", "batch_norm", + "exp", "concat", + "clip", + "cast", "tanh", "pad", "elementwise_add", "elementwise_mul", + "elementwise_sub", + "elementwise_div", + "elementwise_pow", "dropout", + // "deformable_conv", + "prelu", "conv2d_transpose", "leaky_relu", - // "fc", + "log", + "fc", "shuffle_channel", "swish", "split", - // "instance_norm", + "instance_norm", "gelu", - // "layer_norm", - // "scale", - // "stack", + "layer_norm", + "scale", + "slice", + "stack", "relu6", "reshape2", "transpose2", "concat", "slice", + "fill_constant", + "fill_constant_batch_size_like", + "shape", + "unsqueeze2", + "pad3d", + "squeeze2", + "bilinear_interp_v2" + // "yolo_box" }; - framework::ir::FusePassBase::Init("dlnne_subgraph_pass", graph); + // the op which output is special, need special process + static std::unordered_set special_output_op_set{ + "transpose2", + "fill_constant_batch_size_like", + "flatten_contiguous_range", + "batch_norm", + "unsqueeze2", + }; + + // the op when it's shape is dynamic still can be fused by + // dlnne_engine_op + static std::unordered_set dynamic_pass_op_set{ + "reshape2", + }; + auto disable_nodes_by_outputs = + Get>("disable_nodes_by_outputs"); + bool use_static_batch = Get("use_static_batch"); auto teller = [&](const framework::ir::Node *node) { - if (!node->IsOp() || !node->Op()) return false; - return teller_set.find(node->Op()->Type()) != teller_set.end(); + if (!node->IsOp() || !node->Op()) { + return false; + } + if (teller_set.find(node->Op()->Type()) == teller_set.end()) { + VLOG(4) << "don't support op:" << node->Op()->Type(); + return false; + } else { + bool flag = true; + // check node output + if (dynamic_pass_op_set.find(node->Op()->Type()) != + dynamic_pass_op_set.end()) { + flag = true; + } else if (special_output_op_set.find(node->Op()->Type()) == + special_output_op_set.end()) { + for (auto *x : node->outputs) { + std::string var_name = x->Var()->Name(); + flag = IsDynamicOp(var_name, use_static_batch); + if (!flag) break; + } + } else { + std::string var_name = node->outputs[0]->Var()->Name(); + flag = IsDynamicOp(var_name, use_static_batch); + } + // check node input + if (flag) { + for (auto *x : node->inputs) { + std::string var_name = x->Var()->Name(); + flag = IsDynamicOp(var_name, use_static_batch); + if (!flag) break; + } + } + if (!flag) { + VLOG(4) << "don't support dynamic shape:" << node->Op()->Type(); + } + bool flag2 = true; + for (auto *x : node->outputs) { + if (disable_nodes_by_outputs.find(x->Name()) != + disable_nodes_by_outputs.end()) { + flag2 = false; + } + } + if (!flag2) { + VLOG(4) << "user don't use " << node->Name() << "..."; + } + return flag && flag2; + } }; framework::ir::SubGraphFuser fuser( @@ -153,6 +405,45 @@ std::string replace_name(std::string name, return r_name; } +auto fix_batch_as_one( + std::unordered_map *name_var_desc, + std::set *valid_input_names, + bool use_static_batch = false) { + std::unordered_map> name_var_shape; + + if (use_static_batch) { + std::set names; + names.insert(valid_input_names->begin(), valid_input_names->end()); + + for (auto name : names) { + if (name_var_desc->find(name) != name_var_desc->end()) { + auto var_desc = (*name_var_desc)[name]; + auto sp = var_desc->GetShape(); + if (sp[0] == -1) { + sp[0] = 1; + name_var_shape[name] = sp; + std::stringstream sp_str; + copy(sp.begin(), + sp.end(), + std::ostream_iterator(sp_str, ",")); + + LOG(INFO) + << "Warning: fix var:" << name << " batch,shape is [" + << sp_str.str() + << "],we assume subgraph's inputs/outputs first dim is batch," + << "but when the first dim is not mean batch " + << "we suggest you use fix shape model...."; + } + } + } + } + return name_var_shape; +} +/* +there are two ProgramDesc in the function, global_program is used for generate a +Dlnne op, dump_program is used for dump the subgraph to onnx subgraph which is +loaded by Dlnne op +*/ void DlnneSubgraphPass::CreateDlnneOp( framework::ir::Node *node, framework::ir::Graph *graph, @@ -172,22 +463,58 @@ void DlnneSubgraphPass::CreateDlnneOp( block_desc.Proto()->set_idx(0); LOG(INFO) << "--- detect a sub-graph with " << subgraph.size() << " nodes"; // for debug - framework::ProgramDesc tmp_dump_program_desc; - auto *tmp_dump_main_block = tmp_dump_program_desc.MutableBlock(0); + framework::ProgramDesc *global_program = + Get("program"); + const framework::BlockDesc &main_block = + global_program->Block(framework::kRootBlockIndex); - std::unordered_map name_var_desc; - std::set name_var_input_nodes; - std::set name_var_output_nodes; - std::set name_ops; + std::set input_names; + std::set input_names_with_id; + std::vector params; + std::set valid_input_names; + // if we delete fluid copy of params shared by more than 1 ops, there will be + // problem, so we filter them out. + + // The node->inputs contains input tensors and parameters. + for (auto *x : node->inputs) { + input_names.insert(x->Name()); + input_names_with_id.insert(x->Name() + std::to_string(x->id())); + if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) { + params.push_back(x->Name()); + } + if (std::find(graph_params.begin(), graph_params.end(), x->Name()) == + graph_params.end()) { + valid_input_names.insert(x->Name()); + } + } + + std::set output_names; + std::set output_names_with_id; + std::vector origin_output_dims; + std::set valid_output_names; + for (auto *x : node->outputs) { + origin_output_dims.push_back(x->Var()->GetShape().size()); + output_names.insert(x->Name()); + output_names_with_id.insert(x->Name() + std::to_string(x->id())); + if (std::find(graph_params.begin(), graph_params.end(), x->Name()) == + graph_params.end()) { + valid_output_names.insert(x->Name()); + } + } + + auto *child_block = global_program->AppendBlock(main_block); + framework::ProgramDesc dump_program; + auto *export_block = dump_program.MutableBlock(framework::kRootBlockIndex); + std::unordered_map name_var_desc; for (auto *node : subgraph) { auto *op = block_desc.AppendOp(); *op->Proto() = *node->Op()->Proto(); - - // debug + auto *child_op = child_block->AppendOp(); + *child_op->Proto() = *node->Op()->Proto(); + // generate op by node to append on block { - name_ops.insert(node->Name()); - auto *tmp_dump_new_block_op = tmp_dump_main_block->AppendOp(); + auto *export_op = export_block->AppendOp(); framework::OpDesc op_desc; op_desc.CopyFrom(*node->Op()); @@ -204,77 +531,69 @@ void DlnneSubgraphPass::CreateDlnneOp( op_desc.Rename(argument_name, replace_name(argument_name, "/", ".")); } } - *tmp_dump_new_block_op->Proto() = *op_desc.Proto(); + *export_op->Proto() = *op_desc.Proto(); for (auto *x : node->inputs) { if (x->IsVar()) { - name_var_desc[x->Name()] = x->Var(); + auto var_desc_infer = main_block.FindVarRecursive(x->Name()); + if (var_desc_infer != nullptr) { + name_var_desc[x->Name()] = var_desc_infer; + } else { + name_var_desc[x->Name()] = x->Var(); + } } - if (std::count(graph_params.begin(), graph_params.end(), x->Name()) == - 0) - name_var_input_nodes.insert(x->Name()); } for (auto *x : node->outputs) { if (x->IsVar()) { - name_var_desc[x->Name()] = x->Var(); + auto var_desc_infer = main_block.FindVarRecursive(x->Name()); + if (var_desc_infer != nullptr) { + name_var_desc[x->Name()] = var_desc_infer; + } else { + name_var_desc[x->Name()] = x->Var(); + } } - if (std::count(graph_params.begin(), graph_params.end(), x->Name()) == - 0) - name_var_output_nodes.insert(x->Name()); } } } - std::set valid_input_names; - std::set valid_output_names; - for (auto name : name_var_output_nodes) { - if (name_var_input_nodes.find(name) == name_var_input_nodes.end()) { - valid_output_names.insert(name); - } - } - for (auto name : name_var_input_nodes) { - if (name_var_output_nodes.find(name) == name_var_output_nodes.end()) { - valid_input_names.insert(name); - } + // starting fix bath as one + bool use_static_batch = Get("use_static_batch"); + auto name_shape_table = + fix_batch_as_one(*name_var_desc, *valid_input_names, use_static_batch); + + for (const auto &name_shape : name_shape_table) { + VLOG(4) << "Fix batch shape as one var name: " << name_shape.first; } // Then, we will use the input_names_with_id and output_names_with_id to // generate the engine key. // So, We use set instead of unordered_set here to ensure that the engine key // is unique. - std::set input_names; - std::set input_names_with_id; - std::vector params; - // if we delete fluid copy of params shared by more than 1 ops, there will be - // problem, so we filter them out. - - // The node->inputs contains input tensors and parameters. - for (auto *x : node->inputs) { - input_names.insert(x->Name()); - input_names_with_id.insert(x->Name() + std::to_string(x->id())); - if (std::count(graph_params.begin(), graph_params.end(), x->Name()) > 0) { - params.push_back(x->Name()); - } + auto engine_key = GenerateEngineKey( + input_names_with_id, output_names_with_id, std::to_string(0)); + auto precision_mode = Get("precision_mode"); + bool enable_int8 = false; + if (precision_mode == AnalysisConfig::Precision::kInt8) { + enable_int8 = true; } - - std::set output_names; - std::set output_names_with_id; - std::vector origin_output_dims; - for (auto *x : node->outputs) { - origin_output_dims.push_back(x->Var()->GetShape().size()); - output_names.insert(x->Name()); - output_names_with_id.insert(x->Name() + std::to_string(x->id())); + auto use_calib_mode = Get("use_calib_mode"); + + std::string calibration_data_path = "./calibration/dlnne_calib_" + engine_key; + bool calibration_mode = false; + if (enable_int8 && use_calib_mode && !PathExists(calibration_data_path)) { + calibration_mode = true; + MKDIR("./calibration"); + MKDIR(calibration_data_path.c_str()); } - - std::unordered_map output_name_map; - std::unordered_map graph_var_map; - - for (framework::ir::Node *node : graph->Nodes()) { - if (node->IsVar() && node->Var()) { - graph_var_map[node->Name()] = node; - } + VLOG(4) << "calibration_mode: " << calibration_mode; + std::stringstream ss; + ss << "engine_key:" << engine_key << " outputs:["; + for (auto name : valid_output_names) { + ss << name << ","; } + ss << "]"; + VLOG(4) << ss.str(); // Set attrs op_desc->SetType("dlnne_engine"); @@ -285,70 +604,98 @@ void DlnneSubgraphPass::CreateDlnneOp( op_desc->SetOutput("Ys", std::vector(valid_output_names.begin(), valid_output_names.end())); + op_desc->SetBlockAttr("sub_block", child_block); op_desc->SetAttr("parameters", params); - auto engine_key = GenerateEngineKey( - input_names_with_id, output_names_with_id, std::to_string(0)); op_desc->SetAttr("engine_key", engine_key); - auto *scope = param_scope(); - - { - std::set input_names; + op_desc->SetAttr("max_batch_size", Get("max_batch_size")); + op_desc->SetAttr("use_static_batch", Get("use_static_batch")); + op_desc->SetAttr("weight_share_mode", Get("weight_share_mode")); + op_desc->SetAttr("enable_int8", enable_int8); + op_desc->SetAttr("use_calib_mode", use_calib_mode); + op_desc->SetAttr("calibration_mode", calibration_mode); + op_desc->SetAttr("calibration_data_path", calibration_data_path); + + std::string subgraph_root_path = "./dump/" + engine_key; + op_desc->SetAttr("subgraph_root_path", subgraph_root_path); + + std::stringstream ins_stream; + for (auto name : valid_input_names) { + ins_stream << "," << name; + } + op_desc->SetAttr("valid_input_names", ins_stream.str().substr(1)); - for (auto name : name_var_input_nodes) { - if (name_var_output_nodes.find(name) == name_var_output_nodes.end()) { - input_names.insert(name); - } - } + std::stringstream outs_stream; + for (auto name : valid_output_names) { + outs_stream << "," << name; + } + op_desc->SetAttr("valid_output_names", outs_stream.str().substr(1)); + auto *scope = param_scope(); + { // add feed to subgraph: int input_idx = 0; - for (auto input_name : input_names) { - auto *feed0 = tmp_dump_main_block->AppendOp(); - feed0->SetType("feed"); - feed0->SetInput("X", {"feed"}); - feed0->SetOutput("Out", {input_name}); - feed0->SetAttr("col", input_idx); + for (auto input_name : valid_input_names) { + auto *feed1 = export_block->AppendOp(); + feed1->SetType("feed"); + feed1->SetInput("X", {"feed"}); + feed1->SetOutput("Out", {input_name}); + feed1->SetAttr("col", input_idx); input_idx++; } // add fetch to subgraph: int output_idx = 0; for (auto output_name : valid_output_names) { - auto *fetch0 = tmp_dump_main_block->AppendOp(); - fetch0->SetType("fetch"); - fetch0->SetInput("X", {output_name}); - fetch0->SetOutput("Out", {"out"}); - fetch0->SetAttr("col", output_idx); + auto *fetch1 = export_block->AppendOp(); + fetch1->SetType("fetch"); + fetch1->SetInput("X", {output_name}); + fetch1->SetOutput("Out", {"out"}); + fetch1->SetAttr("col", output_idx); output_idx++; } - mkdir("./dump", 0777); - std::string dir_name = "./dump/" + engine_key; - mkdir(dir_name.c_str(), 0777); - ofstream m_stream; - m_stream.open(dir_name + "/__model__", ios::out); - VLOG(4) << "name_var_desc size:" << name_var_desc.size(); for (auto &kv : name_var_desc) { - auto *new_add_var = tmp_dump_main_block->Proto()->add_vars(); - *new_add_var = *kv.second->Proto(); - auto *variable_tmp = scope->FindVar(kv.first); - if (variable_tmp != nullptr) { - *new_add_var->mutable_name() = replace_name(kv.first, "/", "."); - new_add_var->set_persistable(true); + auto *new_add_var1 = export_block->Proto()->add_vars(); + paddle::framework::VarDesc copy_var_desc(*(kv.second->Proto())); + + if (name_shape_table.find(kv.first) != name_shape_table.end()) { + copy_var_desc.SetShape(name_shape_table[kv.first]); + } + *new_add_var1 = *(copy_var_desc.Proto()); + + auto *variable_tmp1 = scope->FindVar(kv.first); + if (variable_tmp1 != nullptr) { + *new_add_var1->mutable_name() = replace_name(kv.first, "/", "."); + new_add_var1->set_persistable(true); } else { - new_add_var->set_persistable(false); + new_add_var1->set_persistable(false); } } + std::string model_str; + dump_program.Proto()->SerializeToString(&model_str); + op_desc->SetAttr("subgraph", model_str); + op_desc->Flush(); + + if (calibration_mode) { + return; + } + + MKDIR("./dump"); + MKDIR(subgraph_root_path.c_str()); + std::ofstream m_stream; + m_stream.open(subgraph_root_path + "/__model__", std::ios::out); + for (auto param_name : params) { auto *var = scope->FindVar(param_name); if (var != nullptr) { auto *var_t = var->GetMutable(); - ofstream p_stream; - p_stream.open(dir_name + "/" + replace_name(param_name, "/", "."), - ios::out); + std::ofstream p_stream; + p_stream.open( + subgraph_root_path + "/" + replace_name(param_name, "/", "."), + std::ios::out); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(var_t->place()); @@ -357,17 +704,8 @@ void DlnneSubgraphPass::CreateDlnneOp( } } - std::string model; - - tmp_dump_program_desc.Proto()->SerializeToString(&model); - m_stream << model; + m_stream << model_str; m_stream.close(); - - op_desc->SetBlockAttr("sub_block", tmp_dump_main_block); - op_desc->SetAttr("subgraph", model); - op_desc->Flush(); - - ConvertGraph(engine_key); } } diff --git a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h index 09f9ec0807f92d8e4b2fc6ea51dabf9799e21fdf..ad8d0e07d070fa4ebf8e793d524b0b1b71c75855 100644 --- a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h +++ b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h @@ -34,9 +34,6 @@ class Node; namespace paddle { namespace inference { - -int ConvertGraph(std::string graph_name); - namespace analysis { class DlnneSubgraphPass : public framework::ir::FusePassBase { @@ -44,6 +41,8 @@ class DlnneSubgraphPass : public framework::ir::FusePassBase { void ApplyImpl(framework::ir::Graph *graph) const override; private: + void InferShapeForDlnneMainGraph() const; + bool IsDynamicOp(std::string var_name, bool use_static_batch) const; void CleanIntermediateOutputs(framework::ir::Node *node); void CreateDlnneOp(framework::ir::Node *x, framework::ir::Graph *graph, diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 24925901312605e60cc71bd0db7c8a2b55eda814..9016ab218741e4b3573563ee56fec26832269deb 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -283,6 +283,13 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { // Dlnne related CP_MEMBER(use_dlnne_); CP_MEMBER(dlnne_min_subgraph_size_); + CP_MEMBER(dlnne_max_batchsize_); + CP_MEMBER(dlnne_use_static_batch_); + CP_MEMBER(dlnne_weight_share_mode_); + CP_MEMBER(dlnne_use_calib_mode_); + CP_MEMBER(dlnne_precision_mode_); + CP_MEMBER(dlnne_disable_nodes_by_outputs_); + CP_MEMBER(dlnne_input_shape_dict_); // MKLDNN related. CP_MEMBER(use_mkldnn_); CP_MEMBER(mkldnn_enabled_op_types_); @@ -544,9 +551,24 @@ void AnalysisConfig::EnableTensorRtEngine( #endif } -void AnalysisConfig::EnableDlnne(int min_subgraph_size) { +void AnalysisConfig::EnableDlnne( + int min_subgraph_size, + int max_batch_size, + bool use_static_batch, + std::string weight_share_mode, + std::unordered_set disable_nodes_by_ouputs, + std::map> dlnne_input_shape_dict, + bool use_calib_mode, + AnalysisConfig::Precision precision_mode) { use_dlnne_ = true; dlnne_min_subgraph_size_ = min_subgraph_size; + dlnne_max_batchsize_ = max_batch_size; + dlnne_use_static_batch_ = use_static_batch; + dlnne_weight_share_mode_ = weight_share_mode; + dlnne_disable_nodes_by_outputs_ = disable_nodes_by_ouputs; + dlnne_input_shape_dict_ = dlnne_input_shape_dict; + dlnne_use_calib_mode_ = use_calib_mode; + dlnne_precision_mode_ = precision_mode; Update(); } diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index f5a51b7c3bc4e77864451362fe4b3dbf6643da29..af8021ea7d8e45d8258499aaf5f21984b7d97bab 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1107,6 +1107,14 @@ void AnalysisPredictor::PrepareArgument() { LOG(INFO) << "Dlnne subgraph is enabled"; argument_.SetUseDlnne(true); argument_.SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_); + argument_.SetDlnneMaxBatchSize(config_.dlnne_max_batchsize_); + argument_.SetDlnneUseStaticBatch(config_.dlnne_use_static_batch_); + argument_.SetDlnneWeightShareMode(config_.dlnne_weight_share_mode_); + argument_.SetDlnneDisableNodesByOutputs( + config_.dlnne_disable_nodes_by_outputs_); + argument_.SetDlnneInputShapeDict(config_.dlnne_input_shape_dict_); + argument_.SetDlnneUseCalibMode(config_.dlnne_use_calib_mode_); + argument_.SetDlnnePrecisionMode(config_.dlnne_precision_mode_); } if (config_.lite_engine_enabled()) { diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index b925a0c361f94ba059b1868aee9c180dad84e58a..7ec169d7893ba53b619ed227c8f6979e063f5c8f 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -663,7 +663,15 @@ struct PD_INFER_DECL AnalysisConfig { void EnableTensorRtInspector(); bool tensorrt_inspector_enabled() { return trt_use_inspector_; } - void EnableDlnne(int min_subgraph_size = 3); + void EnableDlnne( + int min_subgraph_size = 3, + int max_batch_size = 1, + bool use_static_batch = false, + std::string weight_share_mode = "0", + std::unordered_set disable_nodes_by_outputs = {}, + std::map> input_dict = {}, + bool use_calib_mode = false, + AnalysisConfig::Precision precision_mode = Precision::kFloat32); bool dlnne_enabled() const { return use_dlnne_; } /// @@ -1006,6 +1014,13 @@ struct PD_INFER_DECL AnalysisConfig { // dlnne related. bool use_dlnne_{false}; int dlnne_min_subgraph_size_{3}; + int dlnne_max_batchsize_{1}; + std::unordered_set dlnne_disable_nodes_by_outputs_; + bool dlnne_use_static_batch_{true}; + std::string dlnne_weight_share_mode_; + std::map> dlnne_input_shape_dict_{}; + bool dlnne_use_calib_mode_{false}; + Precision dlnne_precision_mode_{Precision::kFloat32}; // memory reuse related. bool enable_memory_optim_{false}; diff --git a/paddle/fluid/inference/capi/pd_config.cc b/paddle/fluid/inference/capi/pd_config.cc index 45fd2e45c19914220cd85f01e9bd1b67a1404f90..475f0ea23190a818abc0f3d687ca56d5586d18a3 100644 --- a/paddle/fluid/inference/capi/pd_config.cc +++ b/paddle/fluid/inference/capi/pd_config.cc @@ -269,12 +269,28 @@ bool PD_TensorrtEngineEnabled(const PD_AnalysisConfig* config) { return config->config.tensorrt_engine_enabled(); } -void PD_EnableDlnne(PD_AnalysisConfig* config, int min_subgraph_size) { - PADDLE_ENFORCE_NOT_NULL( - config, - paddle::platform::errors::InvalidArgument( - "The pointer of analysis configuration shouldn't be nullptr")); - config->config.EnableDlnne(min_subgraph_size); +void PD_EnableDlnne( + PD_AnalysisConfig* config, + int min_subgraph_size, + int max_batch_size, + bool use_static_batch, + std::string weight_share_mode, + std::unordered_set disable_nodes_by_ouputs, + std::map> dlnne_input_shape_dict, + bool use_calib_mode, + AnalysisConfig::Precision precision_mode) { + PADDLE_ENFORCE_NOT_NULL( + config, + paddle::platform::errors::InvalidArgument( + "The pointer of analysis configuration shouldn't be nullptr")); + config->config.EnableDlnne(min_subgraph_size, + max_batch_size, + use_static_batch, + weight_share_mode, + disable_nodes_by_ouputs, + dlnne_input_shape_dict, + use_calib_mode, + precision_mode); } bool PD_DlnneEnabled(const PD_AnalysisConfig* config) { diff --git a/paddle/fluid/operators/dlnne/CMakeLists.txt b/paddle/fluid/operators/dlnne/CMakeLists.txt index a2aa80f2875b8cfebeec4ef141f2b3819d89a9fc..7c674088c9ab12219dac7104b33b6d622c551f25 100644 --- a/paddle/fluid/operators/dlnne/CMakeLists.txt +++ b/paddle/fluid/operators/dlnne/CMakeLists.txt @@ -9,21 +9,19 @@ endforeach() # add nne find_path( DLNNE_INCLUDE_DIR dlnne.h - PATHS $ENV{SOFTWARE_SOURCE_DIR} $ENV{SOFTWARE_SOURCE_DIR}/driver/nne/include + PATHS $ENV{DL_SDK_DIR} $ENV{DL_SDK_DIR}/include/dlnne NO_DEFAULT_PATH) find_library( DLNNE_LIB libdlnne.so - PATHS $ENV{SOFTWARE_BUILD_DIR} $ENV{SOFTWARE_BUILD_DIR}/driver/nne + PATHS $ENV{DL_SDK_DIR} $ENV{DL_SDK_DIR}/lib NO_DEFAULT_PATH) -find_path(CUDA_INCLUDE_DIR cuda.h - $ENV{SOFTWARE_BUILD_DIR}/llvm-project-10/cuda/include) +find_path(CUDA_INCLUDE_DIR cuda.h $ENV{DL_SDK_DIR}/include) find_library( CURT_LIB libcurt.so - PATHS $ENV{SOFTWARE_BUILD_DIR} - $ENV{SOFTWARE_BUILD_DIR}/llvm-project-10/cuda/lib + PATHS $ENV{DL_SDK_DIR} $ENV{DL_SDK_DIR}/lib NO_DEFAULT_PATH) message("DLNNE_INCLUDE_DIR: "${DLNNE_INCLUDE_DIR}) diff --git a/paddle/fluid/operators/dlnne/dlnne_engine_op.cc b/paddle/fluid/operators/dlnne/dlnne_engine_op.cc index 4654e6a9f978a2885c369c97515aa1c6b1085245..6f57726945034dda59414ad6fcc720f6e5445e42 100644 --- a/paddle/fluid/operators/dlnne/dlnne_engine_op.cc +++ b/paddle/fluid/operators/dlnne/dlnne_engine_op.cc @@ -28,6 +28,105 @@ void CopyTensorCpuToDevice(void* dst_ptr, void* src_ptr, int total_bytes) { cudaDeviceSynchronize(); } +std::string ConvertType(paddle::experimental::DataType type) { + switch (type) { + case paddle::experimental::DataType::FLOAT32: { + return "float32"; + } + case paddle::experimental::DataType::INT64: { + return "int64"; + } + case paddle::experimental::DataType::INT32: { + return "int32"; + } + case paddle::experimental::DataType::FLOAT16: { + return "float16"; + } + default: { + PADDLE_THROW( + platform::errors::Fatal("The DLNNE Calibration only support " + "float/float16/int32_t/int64_t input.")); + } + } +} + +int GetDataByte(paddle::experimental::DataType type) { + switch (type) { + case paddle::experimental::DataType::FLOAT32: { + return 4; + } + case paddle::experimental::DataType::INT64: { + return 8; + } + case paddle::experimental::DataType::INT32: { + return 4; + } + case paddle::experimental::DataType::FLOAT16: { + return 2; + } + default: { + PADDLE_THROW( + platform::errors::Fatal("The DLNNE Calibration only support " + "float/float16/int32_t/int64_t input.")); + } + } +} + +std::string GenerateRandomKey() { + std::string str( + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); + std::random_device rd; + std::mt19937 generator(rd()); + + std::shuffle(str.begin(), str.end(), generator); + return str.substr(0, 32); +} + +void ConvertPaddle2Onnx(std::string onnx_file_name, + std::string subgraph_root_path) { + if (!FileExists(onnx_file_name.c_str())) { + std::stringstream convert_cmd; + convert_cmd << "paddle2onnx --model_dir " << subgraph_root_path + << " --save_file " << onnx_file_name << " --opset_version 11"; + LOG(INFO) << convert_cmd.str(); + int convert_flag = system(convert_cmd.str().c_str()); + PADDLE_ENFORCE_EQ( + convert_flag, + 0, + platform::errors::Unavailable("Convert paddle to onnx failed")); + } +} + +void QuantizeOnnx(std::string onnx_file_name, + std::string rlym_file_name, + std::string quantized_rlym_file_name, + std::string dataset_path, + std::string dataset_plugin_path) { + if (!FileExists(rlym_file_name.c_str())) { + std::stringstream convert_cmd; + convert_cmd << "python -m dl convert " << onnx_file_name + << " --output-model " << rlym_file_name; + LOG(INFO) << convert_cmd.str(); + int convert_flag = system(convert_cmd.str().c_str()); + PADDLE_ENFORCE_EQ( + convert_flag, + 0, + platform::errors::Unavailable("Convert onnx to rlym failed")); + } + + if (!FileExists(quantized_rlym_file_name.c_str())) { + std::stringstream quantize_cmd; + quantize_cmd << "python -m dl quantize " + << "--dataset " << dataset_path << " --plugin " + << dataset_plugin_path << " " << rlym_file_name; + LOG(INFO) << quantize_cmd.str(); + int quantize_flag = system(quantize_cmd.str().c_str()); + PADDLE_ENFORCE_EQ(quantize_flag, + 0, + platform::errors::Unavailable("quantize model failed")); + } +} + } // namespace inference namespace operators { @@ -41,7 +140,23 @@ class DlnneEngineOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "engine_key", "The engine_key here is used to distinguish different DLNNE Engines"); - AddAttr("sub_block", "the trt block"); + AddAttr("max_batch_size", "engine max_batch_size"); + AddAttr("use_static_batch", "static batch fix for [?,H,W,C]"); + AddAttr("weight_share_mode", + "dlnne weight_share_mode, can be '0', '1', '2', '3', " + "'01', '23', '0123' "); + // when use_calib_mode is true and enable_int8 is true, + // the calibration_runtime start, + // when calibration_mode is true, the calibration_runtiime + // go to the first stage of calibration, and when finish + // fisrt stage, the calibration_mode is set false, the + // calibration_runtime go to the second stage + AddAttr("use_calib_mode", "dlnne use calib mode"); + AddAttr("enable_int8", "dlnne enable int8"); + AddAttr("calibration_mode", "dlnne calibration_mode"); + AddAttr("calibration_data_path", "calibration data path"); + AddAttr("subgraph_root_path", "subgraph root path"); + AddAttr("sub_block", "the dlnne block"); AddComment("Dlnne engine operator."); } }; diff --git a/paddle/fluid/operators/dlnne/dlnne_engine_op.h b/paddle/fluid/operators/dlnne/dlnne_engine_op.h index 591dab0b77a018237d4cac74eb5461f73a099838..7a925391eb962d74599ed719d0b17353f5a00dac 100644 --- a/paddle/fluid/operators/dlnne/dlnne_engine_op.h +++ b/paddle/fluid/operators/dlnne/dlnne_engine_op.h @@ -13,25 +13,38 @@ // limitations under the License. #pragma once -#include #include // NOTLINT #include // NOTLINT #include // NOTLINT +#include #include #include #include #include +#include +#include #include #include #include #include #include +#include "paddle/fluid/framework/data_device_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/inference/analysis/helper.h" +#include "paddle/fluid/inference/utils/io_utils.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/phi/core/ddim.h" + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" namespace dl { namespace nne { @@ -40,6 +53,31 @@ class Engine; class Network; class Parser; class ExecutionContext; + +inline unsigned int GetElementSize(DataType type) { + switch (type) { + case DataType::kINT64: + case DataType::kUINT64: + case DataType::kFLOAT64: + return 8; + case DataType::kINT32: + case DataType::kUINT32: + case DataType::kFLOAT32: + return 4; + case DataType::kINT16: + case DataType::kUINT16: + case DataType::kFLOAT16: + return 2; + case DataType::kINT8: + case DataType::kUINT8: + case DataType::kBOOL: + return 1; + case DataType::kUNKNOWN_TYPE: + return 0; + } + return 0; +} + } // namespace nne } // namespace dl @@ -61,8 +99,45 @@ void CopyTensorDeviceToCpu(void *dst_ptr, void *src_ptr, int total_bytes); void CopyTensorCpuToDevice(void *dst_ptr, void *src_ptr, int total_bytes); -template -struct Singleton; +std::string ConvertType(paddle::experimental::DataType type); + +int GetDataByte(paddle::experimental::DataType type); + +std::string GenerateRandomKey(); + +void ConvertPaddle2Onnx(std::string onnx_file_name, + std::string subgraph_root_path); + +void QuantizeOnnx(std::string onnx_file_name, + std::string rlym_file_name, + std::string quantized_rlym_file_name, + std::string dataset_path, + std::string dataset_plugin_path); + +static paddle::experimental::DataType DLNNE2FluidDataType( + dl::nne::DataType type) { + switch (type) { + case dl::nne::DataType::kFLOAT32: + return paddle::experimental::DataType::FLOAT32; + case dl::nne::DataType::kINT32: + return paddle::experimental::DataType::INT32; + case dl::nne::DataType::kINT64: + return paddle::experimental::DataType::INT64; + case dl::nne::DataType::kFLOAT16: + return paddle::experimental::DataType::FLOAT16; + case dl::nne::DataType::kUINT8: + return paddle::experimental::DataType::UINT8; + case dl::nne::DataType::kINT8: + return paddle::experimental::DataType::INT8; + case dl::nne::DataType::kBOOL: + return paddle::experimental::DataType::BOOL; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "unknown fluid datatype in Fluid op converter")); + return paddle::experimental::DataType::FLOAT32; + } +} + } // namespace inference } // namespace paddle @@ -70,15 +145,26 @@ namespace paddle { namespace operators { +std::mutex static dlnne_create_lock; + class DlnneEngineOp : public framework::OperatorBase { private: std::vector input_names_; std::unordered_set param_names_; std::string engine_key_; + bool use_static_batch_; + bool calibration_mode_; + std::string calibration_data_path_; + std::string subgraph_root_path_; + bool enable_int8_; + bool use_calib_mode_; + + std::string weight_share_mode_; + int max_batch_size_; int num_inputs; int num_outputs; - std::vector output_names; - std::vector input_names; + // std::vector output_names; + // std::vector input_names; dl::nne::Builder *builder; dl::nne::Parser *parser; @@ -89,6 +175,10 @@ class DlnneEngineOp : public framework::OperatorBase { unsigned int engine_input_size; std::vector InputIndexToBindIndex_; + char *dump_flag_; + char *dlnne_log_flag_; + char *dl_sdk_dir_; + public: DlnneEngineOp(const std::string &type, const framework::VariableNameMap &inputs, @@ -97,81 +187,214 @@ class DlnneEngineOp : public framework::OperatorBase { : framework::OperatorBase(type, inputs, outputs, attrs) { input_names_ = Inputs("Xs"); engine_key_ = Attr("engine_key"); + use_static_batch_ = Attr("use_static_batch"); + max_batch_size_ = Attr("max_batch_size"); + weight_share_mode_ = Attr("weight_share_mode"); + calibration_mode_ = Attr("calibration_mode"); + calibration_data_path_ = Attr("calibration_data_path"); + subgraph_root_path_ = Attr("subgraph_root_path"); + enable_int8_ = Attr("enable_int8"); + use_calib_mode_ = Attr("use_calib_mode"); + + // dump input/output buffer of dlnne engine + dump_flag_ = getenv("PADDLE_DUMP_DLNNE_BUFFER"); + dlnne_log_flag_ = getenv("PADDLE_DLNNE_LOG"); + dl_sdk_dir_ = getenv("DL_SDK_DIR"); + auto params = Attr>("parameters"); for (const auto ¶m : params) { param_names_.insert(param); } - num_inputs = 0; + std::vector XsMap; + num_inputs = Inputs("Xs").size(); + std::string valid_input_name_str = Attr("valid_input_names"); + for (const auto &x : Inputs("Xs")) { - if (param_names_.count(x)) continue; - num_inputs += 1; - input_names.push_back(x); + // input_names.push_back(x); + XsMap.push_back( + valid_input_name_str.substr(0, valid_input_name_str.find(","))); + valid_input_name_str = + valid_input_name_str.substr(valid_input_name_str.find(",") + 1); } + std::vector YsMap; num_outputs = Outputs("Ys").size(); + std::string valid_output_name_str = Attr("valid_output_names"); for (const auto &y : Outputs("Ys")) { - VLOG(4) << "y: " << y << std::endl; - output_names.push_back(y); + // output_names.push_back(y); + YsMap.push_back( + valid_output_name_str.substr(0, valid_output_name_str.find(","))); + valid_output_name_str = + valid_output_name_str.substr(valid_output_name_str.find(",") + 1); } - // onnx path - std::stringstream filename; - std::string current_path = "."; - char *buffer; - if ((buffer = getcwd(NULL, 0)) != NULL) { - current_path = buffer; - } else { - current_path = "."; - } - filename << current_path << "/dump/" << engine_key_ << "/" << engine_key_ - << ".onnx"; - - builder = dl::nne::CreateInferBuilder(); - PADDLE_ENFORCE_NE( - builder, - nullptr, - platform::errors::Unavailable("nne create builder failed")); - parser = dl::nne::CreateParser(); - PADDLE_ENFORCE_NE( - parser, - nullptr, - platform::errors::Unavailable("nne create parser failed")); - - network = builder->CreateNetwork(); - - LOG(INFO) << "set output for dlnne"; - for (std::string &output_op_name : output_names) - parser->RegisterOutput(output_op_name.c_str()); - - LOG(INFO) << "parser onnx for dlnne"; - parser->Parse(filename.str().c_str(), *network); - - LOG(INFO) << "build network"; - engine = builder->BuildEngine(*network); - - // total size = input_size+output_size - engine_input_size = num_inputs + num_outputs; - for (std::string &input_name : input_names) { - int BindIndex = engine->GetBindingIndex(input_name.c_str()); - InputIndexToBindIndex_.push_back(BindIndex); - } + // TODO(pei.jiang): add dlnne_engine manager to manage dlnne_engine + if (!calibration_mode_) { + std::map weight_share_map; + weight_share_map.insert( + std::make_pair("0", dl::nne::WeightShareMode::kSingle)); + weight_share_map.insert( + std::make_pair("1", dl::nne::WeightShareMode::kSingle)); + weight_share_map.insert( + std::make_pair("2", dl::nne::WeightShareMode::kSingle)); + weight_share_map.insert( + std::make_pair("3", dl::nne::WeightShareMode::kSingle)); + weight_share_map.insert( + std::make_pair("01", dl::nne::WeightShareMode::kShare2)); + weight_share_map.insert( + std::make_pair("23", dl::nne::WeightShareMode::kShare2)); + weight_share_map.insert( + std::make_pair("0123", dl::nne::WeightShareMode::kShare4)); + + std::map cluster_config_map; + cluster_config_map.insert( + std::make_pair("0", dl::nne::ClusterConfig::kCluster0)); + cluster_config_map.insert( + std::make_pair("1", dl::nne::ClusterConfig::kCluster1)); + cluster_config_map.insert( + std::make_pair("2", dl::nne::ClusterConfig::kCluster2)); + cluster_config_map.insert( + std::make_pair("3", dl::nne::ClusterConfig::kCluster3)); + cluster_config_map.insert( + std::make_pair("01", dl::nne::ClusterConfig::kCluster01)); + cluster_config_map.insert( + std::make_pair("23", dl::nne::ClusterConfig::kCluster23)); + cluster_config_map.insert( + std::make_pair("0123", dl::nne::ClusterConfig::kCluster0123)); + + dl::nne::WeightShareMode mode = weight_share_map[weight_share_mode_]; + dl::nne::ClusterConfig cluster_config = + cluster_config_map[weight_share_mode_]; + if (dlnne_log_flag_) { + LOG(INFO) << "weight_share_mode: " << mode + << " cluster_config: " << cluster_config; + } - for (std::string &output_name : output_names) { - int BindIndex = engine->GetBindingIndex(output_name.c_str()); - InputIndexToBindIndex_.push_back(BindIndex); - } + std::string onnx_file_name = + subgraph_root_path_ + "/" + engine_key_ + ".onnx"; + inference::ConvertPaddle2Onnx(onnx_file_name, subgraph_root_path_); + + std::string rlym_file_name = + subgraph_root_path_ + "/" + engine_key_ + ".rlym"; + // quantize don't support set quantized ouput model path now, + // the quantized model file is in current dir + std::string quantized_rlym_file_name = engine_key_ + ".quantized.rlym"; + + std::stringstream filename; + std::stringstream engine_file_name; + + if (enable_int8_ && use_calib_mode_) { + std::string dataset_path = calibration_data_path_; + std::string cnt_dataset_path = dataset_path + "/" + input_names_[0]; + + std::stringstream dataset_plugin_path; + dataset_plugin_path << dl_sdk_dir_ + << "/python/dleol/quantize/plugin.py"; + + inference::QuantizeOnnx(onnx_file_name, + rlym_file_name, + quantized_rlym_file_name, + dataset_path, + dataset_plugin_path.str()); + + filename << quantized_rlym_file_name; + engine_file_name << subgraph_root_path_ << "/" << engine_key_ + << "_quantized" + << "_ws_" << weight_share_mode_ << ".engine"; + } else { + filename << onnx_file_name; + engine_file_name << subgraph_root_path_ << "/" << engine_key_ << "_ws_" + << weight_share_mode_ << ".engine"; + } + + dlnne_create_lock.lock(); + if (dlnne_log_flag_) { + LOG(INFO) << "EngineKey:" << engine_key_ + << " use_static_batch_:" << use_static_batch_ + << " max_batch_size_:" << max_batch_size_ + << " weight_share_mode_: " << weight_share_mode_; + } + + builder = dl::nne::CreateInferBuilder(); + PADDLE_ENFORCE_NE( + builder, + nullptr, + platform::errors::Unavailable("nne create builder failed")); + dl::nne::BuilderConfig builder_cfg; + builder_cfg.max_batch_size = max_batch_size_; + builder_cfg.ws_mode = weight_share_map[weight_share_mode_]; + builder->SetBuilderConfig(builder_cfg); + network = builder->CreateNetwork(); + + parser = dl::nne::CreateParser(); + PADDLE_ENFORCE_NE( + parser, + nullptr, + platform::errors::Unavailable("nne create parser failed")); + if (dlnne_log_flag_) { + LOG(INFO) << "set output for dlnne"; + } + for (std::string &output_op_name : YsMap) { + parser->RegisterOutput(output_op_name.c_str()); + if (dlnne_log_flag_) { + LOG(INFO) << output_op_name; + } + } + + std::fstream engine_file; + engine_file.open(engine_file_name.str().c_str(), std::ios::in); + if (!engine_file) { + if (dlnne_log_flag_) { + LOG(INFO) << "parser model file for dlnne"; + } + parser->Parse(filename.str().c_str(), *network); + if (dlnne_log_flag_) { + LOG(INFO) << "build network"; + } + engine = builder->BuildEngine(*network); + + auto memory = engine->Serialize(); + std::ofstream out(engine_file_name.str().c_str(), + std::ofstream::binary); + out.write(reinterpret_cast(memory->Data()), memory->Size()); + out.close(); + memory->Destroy(); + } else { + engine_file.seekg(0, std::ios::end); + uint64_t length = static_cast(engine_file.tellg()); + engine_file.seekg(0, std::ios::beg); + char *slz_data = new char[length]; + engine_file.read(slz_data, static_cast(length)); + engine = dl::nne::Deserialize(slz_data, length); + delete[] slz_data; + } - // context - context = engine->CreateExecutionContext(); + engine_input_size = num_inputs + num_outputs; + for (std::string &input_name : XsMap) { + int BindIndex = engine->GetBindingIndex(input_name.c_str()); + InputIndexToBindIndex_.push_back(BindIndex); + } + for (std::string &output_name : YsMap) { + int BindIndex = engine->GetBindingIndex(output_name.c_str()); + InputIndexToBindIndex_.push_back(BindIndex); + } + + // context + context = engine->CreateExecutionContext( + cluster_config_map[weight_share_mode_]); + dlnne_create_lock.unlock(); + } } ~DlnneEngineOp() { - network->Destroy(); - context->Destroy(); - engine->Destroy(); - parser->Destroy(); - builder->Destroy(); + if (!calibration_mode_) { + network->Destroy(); + context->Destroy(); + engine->Destroy(); + parser->Destroy(); + builder->Destroy(); + } } protected: @@ -190,7 +413,42 @@ class DlnneEngineOp : public framework::OperatorBase { std::vector input_data_types(num_inputs); std::vector input_bytes(num_inputs); + dlnne_create_lock.lock(); int index = 0; + int infer_batch = 1; + std::vector vec_infer_batch; + // compute infer_batch + if (use_static_batch_) { + for (const auto &x : Inputs("Xs")) { + if (param_names_.count(x)) continue; + // convert input and copy to Dlnne engine's buffer + auto &t = + inference::analysis::GetFromScope(scope, x); + + auto t_shape = phi::vectorize(t.dims()); + std::vector runtime_input_shape(t_shape.begin(), + t_shape.end()); + const int bind_index = index; + index++; + dl::nne::Dims in_dim = engine->GetBindingDimensions(bind_index); + + int compute_batch = runtime_input_shape[0] / in_dim.d[0]; + VLOG(4) << "compute batch: " << compute_batch; + vec_infer_batch.push_back(compute_batch); + } + + int first_batch = vec_infer_batch[0]; + for (auto batch : vec_infer_batch) { + PADDLE_ENFORCE_EQ( + first_batch, + batch, + platform::errors::Unavailable( + "compute infer_batchs is different from each other")); + } + infer_batch = first_batch; + } + + index = 0; for (const auto &x : Inputs("Xs")) { if (param_names_.count(x)) continue; // convert input and copy to Dlnne engine's buffer @@ -199,26 +457,33 @@ class DlnneEngineOp : public framework::OperatorBase { const int bind_index = index; index++; - int64_t data_bytes; + int64_t data_bytes, ele_num; int32_t dtype; - auto type = framework::TransToProtoVarType(t.dtype()); + auto type = t.type(); data_bytes = 1; + ele_num = 1; void *buffer = nullptr; - if (type == framework::proto::VarType::FP32) { + // TODO(pei.jiang): add more type + if (type == paddle::experimental::DataType::FLOAT32) { buffer = static_cast(t.data()); data_bytes = 4; dtype = 0; - } else if (type == framework::proto::VarType::INT64) { + } else if (type == paddle::experimental::DataType::INT64) { buffer = static_cast(t.data()); data_bytes = 8; dtype = 1; - } else if (type == framework::proto::VarType::INT32) { + } else if (type == paddle::experimental::DataType::INT32) { buffer = static_cast(t.data()); data_bytes = 4; dtype = 2; + } else if (type == paddle::experimental::DataType::FLOAT16) { + buffer = static_cast(t.data()); + data_bytes = 2; + dtype = 3; } else { - PADDLE_THROW(platform::errors::Fatal( - "The DLNNE Engine OP only support float/int32_t/int64_t input.")); + PADDLE_THROW( + platform::errors::Fatal("The DLNNE Engine OP only support " + "float/int32_t/int64_t/float16 input.")); } input_buffers[bind_index] = buffer; @@ -226,6 +491,7 @@ class DlnneEngineOp : public framework::OperatorBase { std::vector runtime_input_shape(t_shape.begin(), t_shape.end()); for (auto &size : t_shape) { data_bytes = data_bytes * size; + ele_num = ele_num * size; } VLOG(4) << "buffers_size:" << data_bytes; @@ -234,35 +500,59 @@ class DlnneEngineOp : public framework::OperatorBase { input_shapes[bind_index] = runtime_input_shape; input_data_types[bind_index] = dtype; input_bytes[bind_index] = data_bytes; + + if (dump_flag_) { + std::stringstream dump_input_name; + dump_input_name << engine_key_ << "_input_" << bind_index << ".txt"; + std::ofstream dump_input_file; + dump_input_file.open(dump_input_name.str()); + for (int64_t i = 0; i < ele_num; i++) { + dump_input_file << static_cast( + cpu_input_buffers[bind_index])[i] + << "\n"; + } + dump_input_file << "\b"; + dump_input_file.close(); + } } // output shape std::vector> out_shapes; + std::vector out_types; + std::vector out_ele_nums; std::vector output_bytes; for (int i = 0; i < num_outputs; i++) { - int index = engine->GetBindingIndex(output_names[i].c_str()); + int index = InputIndexToBindIndex_[i + num_inputs]; + dl::nne::DataType out_type = engine->GetBindingDataType(index); + out_types.push_back(out_type); dl::nne::Dims out_dim = engine->GetBindingDimensions(index); std::vector shape(out_dim.nbDims); for (int dim = 0; dim < out_dim.nbDims; dim++) { - shape[dim] = (out_dim.d[dim]); + if (use_static_batch_ && dim == 0) { + shape[dim] = (out_dim.d[dim]) * infer_batch; + } else { + shape[dim] = (out_dim.d[dim]); + } } out_shapes.push_back(shape); - int64_t data_bytes; + int64_t data_bytes, out_ele_num; + out_ele_num = 1; // float32 - data_bytes = 4; + data_bytes = dl::nne::GetElementSize(out_type); for (auto &size : shape) { data_bytes = data_bytes * size; + out_ele_num = out_ele_num * size; } VLOG(4) << "data_bytes: " << data_bytes; output_bytes.push_back(data_bytes); + out_ele_nums.push_back(out_ele_num); } int bind_index = 0; std::vector cpu_output_buffers(num_outputs); std::vector output_buffers(num_outputs); - std::vector output_dtypes(num_outputs); for (const auto &y : Outputs("Ys")) { auto *fluid_v = scope.FindVar(y); @@ -273,15 +563,19 @@ class DlnneEngineOp : public framework::OperatorBase { auto *fluid_t = fluid_v->GetMutable(); - VLOG(4) << "out_shapes[bind_index] dim:" << out_shapes[bind_index].size(); + VLOG(4) << bind_index << ": out_shapes[bind_index] dim:" + << out_shapes[bind_index].size(); fluid_t->Resize(phi::make_ddim(out_shapes[bind_index])); - int32_t dtype; - output_buffers[bind_index] = fluid_t->mutable_data(dev_place); - dtype = 0; + dl::nne::DataType dl_type = out_types[bind_index]; + if (dlnne_log_flag_) { + LOG(INFO) << "output type: " << dl_type; + } + output_buffers[bind_index] = static_cast(fluid_t->mutable_data( + dev_place, inference::DLNNE2FluidDataType(dl_type))); + cpu_output_buffers[bind_index] = output_buffers[bind_index]; // malloc(data_bytes); - output_dtypes[bind_index] = dtype; bind_index++; } @@ -289,7 +583,9 @@ class DlnneEngineOp : public framework::OperatorBase { // set input_ptr for (unsigned int i = 0; i < engine_input_size; i++) { - if (InputIndexToBindIndex_[i] < 0) continue; + if (InputIndexToBindIndex_[i] < 0) { + continue; + } if (engine->BindingIsInput(InputIndexToBindIndex_[i])) { // copy cpu buffer to gpu buffer @@ -308,7 +604,7 @@ class DlnneEngineOp : public framework::OperatorBase { } else { int64_t total_size; - total_size = output_bytes[i - input_names.size()]; + total_size = output_bytes[i - input_names_.size()]; VLOG(4) << "output_bytes: " << total_size; void *gpu_ptr; cudaMalloc(&gpu_ptr, total_size); @@ -318,36 +614,142 @@ class DlnneEngineOp : public framework::OperatorBase { clock_t startTime, endTime; startTime = clock(); - context->Execute(1, engine_input_ptr.data()); + context->Execute(infer_batch, engine_input_ptr.data()); endTime = clock(); - double during_ms = - static_cast(endTime - startTime) / CLOCKS_PER_SEC * 1000; - LOG(INFO) << "dlNNE execute time: " << during_ms << " ms"; + + if (dlnne_log_flag_) { + double during_ms = + static_cast(endTime - startTime) / CLOCKS_PER_SEC * 1000; + LOG(INFO) << "dlNNE execute time: " << during_ms << " ms"; + } bind_index = 0; for (unsigned int i = 0; i < engine_input_size; i++) { if (InputIndexToBindIndex_[i] < 0) continue; - if (i >= input_names.size()) { - void *cpu_ptr = cpu_output_buffers[i - input_names.size()]; + if (i >= input_names_.size()) { + void *cpu_ptr = cpu_output_buffers[i - input_names_.size()]; int64_t size; - size = output_bytes[i - input_names.size()]; + size = output_bytes[i - input_names_.size()]; paddle::inference::CopyTensorDeviceToCpu( cpu_ptr, engine_input_ptr[InputIndexToBindIndex_[i]], size); - // dtype: float32 - int32_t dtypes; - dtypes = 0; cpu_output_buffers[bind_index] = cpu_ptr; - output_dtypes[bind_index] = dtypes; + + if (dump_flag_) { + std::stringstream dump_output_name; + dump_output_name << engine_key_ << "_output_" << bind_index << ".txt"; + std::ofstream dump_output_file; + dump_output_file.open(dump_output_name.str()); + for (int64_t i = 0; i < out_ele_nums[bind_index]; i++) { + dump_output_file + << static_cast(cpu_output_buffers[bind_index])[i] + << "\n"; + } + dump_output_file << "\b"; + dump_output_file.close(); + } bind_index++; } cudaFree(engine_input_ptr[InputIndexToBindIndex_[i]]); } + dlnne_create_lock.unlock(); + } + + void RunNativeImpl(const framework::Scope &scope, + const platform::Place &dev_place) const { + VLOG(4) << "RunNativeImpl"; + framework::Executor executor(dev_place); + auto *block = Attr("sub_block"); + auto *program = block->Program(); + auto ¤t_scope = scope.NewScope(); + auto ctx = executor.Prepare(*program, block->ID()); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, true); + } + + void RunCalibration(const framework::Scope &scope, + const platform::Place &dev_place) const { + std::unordered_map calib_data_map; + std::unordered_map> calib_data_shape_map; + std::unordered_map calib_data_type_map; + std::unordered_map calib_buffer_size_map; + + for (auto &x : Inputs("Xs")) { + if (param_names_.count(x)) continue; + auto &t = + inference::analysis::GetFromScope(scope, x); + calib_data_map.emplace(x, t.data()); + + // TODO(pei.jiang): refine this code, because when run dlnne create + // engine, there is same code + auto t_shape = phi::vectorize(t.dims()); + std::vector input_shape(t_shape.begin(), t_shape.end()); + calib_data_shape_map.emplace(x, input_shape); + std::string data_type = inference::ConvertType(t.type()); + calib_data_type_map.emplace(x, data_type); + + int data_bytes = inference::GetDataByte(t.type()); + VLOG(4) << "input name: " << x << ", data_type: " << data_type; + VLOG(4) << "data shape: "; + int64_t buffer_size = data_bytes; + for (auto dim : input_shape) { + buffer_size *= dim; + VLOG(4) << dim; + } + VLOG(4) << "buffer_size: " << buffer_size; + calib_buffer_size_map.emplace(x, buffer_size); + } + + std::string random_key = inference::GenerateRandomKey(); + for (auto calib_data : calib_data_map) { + std::string input_name = calib_data.first; + std::string input_data_path = calibration_data_path_ + "/" + input_name; + MKDIR(input_data_path.c_str()); + + std::string input_data_item_path = + input_data_path + "/" + random_key + ".binary"; + auto outfile = std::fstream(input_data_item_path.c_str(), + std::ios::out | std::ios::binary); + int64_t buffer_size = calib_buffer_size_map[input_name]; + outfile.write(reinterpret_cast(calib_data.second), buffer_size); + outfile.close(); + } + + std::stringstream calib_config_ss; + calib_config_ss << "shape message: " << std::endl; + for (auto const &shape_item : calib_data_shape_map) { + calib_config_ss << shape_item.first << ":"; + for (auto const &dim : shape_item.second) { + calib_config_ss << dim << " "; + } + calib_config_ss << std::endl; + } + + calib_config_ss << "dtype message: " << std::endl; + for (auto const &dtype_item : calib_data_type_map) { + calib_config_ss << dtype_item.first << ":" << dtype_item.second + << std::endl; + } + + std::ofstream calib_config_file; + std::string calib_config_path = + calibration_data_path_ + "/calib_config.txt"; + calib_config_file.open(calib_config_path); + calib_config_file << calib_config_ss.str(); + calib_config_file.close(); + + RunNativeImpl(scope, dev_place); } void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { + VLOG(4) << "calibration_mode_: " << calibration_mode_; + if (calibration_mode_ == true) { + VLOG(4) << "RunCalibration"; + RunCalibration(scope, dev_place); + return; + } + RunDlnneOnCreateEngine(scope, dev_place); } }; diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 14975ac337aed61f355d6d60d02d54fb43d6d80e..20af07497c6d0f9e59d044571a1eafead8f3c59e 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -730,7 +730,16 @@ void BindAnalysisConfig(py::module *m) { .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled) .def("enable_dlnne", &AnalysisConfig::EnableDlnne, - py::arg("min_subgraph_size") = 3) + py::arg("min_subgraph_size") = 3, + py::arg("max_batch_size") = 1, + py::arg("use_static_batch") = false, + py::arg("weight_share_mode") = "0", + py::arg("disable_nodes_by_outputs") = + std::unordered_set(), + py::arg("input_shape_dict") = + std::map>(), + py::arg("use_calib_mode") = false, + py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32) .def("enable_lite_engine", &AnalysisConfig::EnableLiteEngine, py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,