diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
index ca8ab8aa71ef6bb79e7a546497dc4760d0a4857e..efaf79d48b3f6e975b0f27ddd41135ccd3087c4b 100644
--- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
+++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc
@@ -14,7 +14,10 @@
 
 #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
 
+#include <algorithm>
+#include <iterator>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 
 #include "paddle/fluid/framework/block_desc.h"
@@ -39,7 +42,106 @@ namespace analysis {
 
 namespace {
 
-bool IsKernelSupportPrecision(
+inline std::string SerializeParams(framework::Scope* scope,
+                                   const std::vector<std::string>& params) {
+  std::ostringstream os;
+  phi::CPUContext ctx;
+  for (const auto& param : params) {
+    VLOG(3) << "Serialize param: " << param;
+    PADDLE_ENFORCE_NOT_NULL(
+        scope->FindVar(param),
+        platform::errors::NotFound("Block should already have a '%s' variable",
+                                   param));
+    auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
+    framework::SerializeToStream(os, *tensor, ctx);
+  }
+  return os.str();
+}
+
+inline void StrToBinary(const std::string& path, const std::string& str) {
+  std::ofstream file(path.c_str(), std::ios::binary);
+  file.write(str.c_str(), str.size());
+  file.close();
+}
+inline bool NodeVarHasDtype(framework::ir::Node* node) {
+  if (node->IsCtrlVar()) return false;
+
+  if (node->IsVar() &&
+      (node->Var()->GetType() ==
+           paddle::framework::proto::VarType::SELECTED_ROWS ||
+       node->Var()->GetType() ==
+           paddle::framework::proto::VarType::LOD_TENSOR ||
+       node->Var()->GetType() ==
+           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
+       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
+       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
+    return true;
+  }
+
+  return false;
+}
+void SaveMixedModel(framework::ir::Graph* graph,
+                    framework::Scope* scope,
+                    framework::ProgramDesc* mixed_program_desc,
+                    const std::string& mixed_model_file,
+                    const std::string& mixed_params_file,
+                    phi::DataType mixed_precision) {
+  paddle::CPUPlace place;
+  auto parameters = scope->LocalVarNames();
+  std::sort(parameters.begin(), parameters.end());
+
+  std::unordered_set<std::string> weights_should_be_fp32;
+  for (auto* node : graph->Nodes()) {
+    if (!(node->IsVar() && !node->IsCtrlVar())) continue;
+    if (NodeVarHasDtype(node)) {
+      if (node->Var()->Persistable() &&
+          node->Var()->GetDataType() ==
+              paddle::framework::proto::VarType::FP32) {
+        VLOG(2) << "weights keep to fp32: " << node->Name();
+        weights_should_be_fp32.insert(node->Name());
+      }
+    }
+  }
+
+  for (const auto& param_name : parameters) {
+    auto* var = scope->FindLocalVar(param_name);
+    if (var->IsType<framework::LoDTensor>() ||
+        var->IsType<framework::Tensor>()) {
+      auto* t = var->GetMutable<framework::LoDTensor>();
+      framework::Tensor mixed_tensor;
+      mixed_tensor.Resize(t->dims());
+      auto* data = t->mutable_data<float>(platform::CPUPlace());
+
+      if (mixed_precision == phi::DataType::FLOAT16 &&
+          !weights_should_be_fp32.count(param_name)) {
+        mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
+        auto* mixed_data =
+            mixed_tensor.mutable_data<phi::dtype::float16>(platform::CPUPlace());
+        for (int i = 0; i < t->numel(); i++) {
+          mixed_data[i] = static_cast<phi::dtype::float16>(data[i]);
+        }
+        t->clear();
+        paddle::framework::TensorCopySync(mixed_tensor, place, t);
+      } else if (mixed_precision == phi::DataType::BFLOAT16 &&
+                 !weights_should_be_fp32.count(param_name)) {
+        mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
+        auto* mixed_data =
+            mixed_tensor.mutable_data<phi::dtype::bfloat16>(platform::CPUPlace());
+        for (int i = 0; i < t->numel(); i++) {
+          mixed_data[i] = static_cast<phi::dtype::bfloat16>(data[i]);
+        }
+        t->clear();
+        paddle::framework::TensorCopySync(mixed_tensor, place, t);
+      }
+    }
+  }
+
+  StrToBinary(mixed_model_file,
+              mixed_program_desc->Proto()->SerializeAsString());
+  StrToBinary(mixed_params_file, SerializeParams(scope, parameters));
+}
+
+bool PhiKernelSupportPrecision(
     const std::string& op_type,
     phi::Backend backend,
     phi::DataType data_type,
@@ -56,10 +158,23 @@ bool GpuKernelSupportPrecision(
     const std::string& op_type,
     phi::DataType data_type,
     phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
-  bool res =
-      IsKernelSupportPrecision(op_type, phi::Backend::GPU, data_type, layout);
-  res |= IsKernelSupportPrecision(
-      op_type, phi::Backend::GPUDNN, data_type, layout);
+  auto phi_op_type = phi::TransToPhiKernelName(op_type);
+  bool res = PhiKernelSupportPrecision(
+      phi_op_type, phi::Backend::GPU, data_type, layout);
+  res |= PhiKernelSupportPrecision(
+      phi_op_type, phi::Backend::GPUDNN, data_type, layout);
+
+  if (!res) {
+    auto& all_kernels = OperatorWithKernel::AllOpKernels();
+    auto it = all_kernels.find(op_type);
+    if (it != all_kernels.end()) {
+      for (auto& kern_pair : it->second) {
+        if (platform::is_gpu_place(kern_pair.first.place_)) {
+          res = true;
+        }
+      }
+    }
+  }
   return res;
 }
 
@@ -90,30 +205,16 @@ bool OutShouldNotConvert(ir::Node* var_node) {
 
   return false;
 }
-
-// Get weight names which appear in multiple block (block 0 and block n).
-std::unordered_set<std::string> GetMultiBlockPersistableNames(
-    framework::ProgramDesc* program_desc) {
-  std::unordered_set<std::string> special_weights;
-  size_t block_size = program_desc->Size();
-
-  std::unordered_set<std::string> block_0_weights;
-  for (auto var : program_desc->Block(0).AllVars()) {
-    if (var->Persistable()) block_0_weights.insert(var->Name());
-  }
-
-  for (size_t i = 1; i < block_size; ++i) {
-    // std::cout << program_desc->MutableBlock(i)->Proto()->DebugString() <<
-    // std::endl;;
-    auto all_ops = program_desc->Block(i).AllOps();
-    for (auto op : all_ops) {
-      for (auto name : op->InputArgumentNames()) {
-        if (block_0_weights.count(name)) special_weights.insert(name);
-      }
-    }
+void ProcessOutputNode(ir::Node* var_node,
+                       framework::proto::VarType::Type to_type) {
+  if (!NodeVarHasDtype(var_node)) return;
+  auto* out_var = var_node->Var();
+  if (out_var->GetDataType() == framework::proto::VarType::FP32) {
+    if (OutShouldNotConvert(var_node)) return;
+    out_var->SetDataType(to_type);
   }
-
-  return special_weights;
+  VLOG(3) << " out_node name " << var_node->Name() << " data_type "
+          << out_var->GetDataType();
 }
 
 // Just process special cases for weights conversion.
@@ -143,21 +244,8 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
     }
   }
 
-  // If cur_op's next is condition_flow op, then cur op should be fp32. Note, we
-  // now only convert to mixed in block 0.
-  for (auto* op_node : op_nodes) {
-    for (auto var : op_node->outputs) {
-      for (auto next_op : var->outputs) {
-        if (next_op->Op()->HasAttr("sub_block")) {
-          return true;
-        }
-      }
-    }
-  }
-
   return false;
 }
-
 inline bool IsFloatVarType(framework::proto::VarType::Type type) {
   if (type == framework::proto::VarType::FP16 ||
       type == framework::proto::VarType::FP32 ||
@@ -165,6 +253,56 @@ inline bool IsFloatVarType(framework::proto::VarType::Type type) {
     return true;
   return false;
 }
+void ProcessInputNode(
+    bool support_precision,
+    framework::ir::Graph* graph,
+    ir::Node* in_node,
+    ir::Node* op_node,
+    int* suffix,
+    framework::BlockDesc* block_desc,
+    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map,
+    framework::proto::VarType::Type to_type,
+    bool is_main_block,
+    std::unordered_map<std::string, framework::proto::VarType::Type>*
+        vars_in_multi_block_map) {
+  if (!NodeVarHasDtype(in_node)) return;
+  auto* in_var = in_node->Var();
+  auto in_var_type = in_var->GetDataType();
+  if (!is_main_block && vars_in_multi_block_map->count(in_var->Name())) {
+    in_var_type = vars_in_multi_block_map->at(in_var->Name());
+  }
+  if (support_precision) {
+    if (in_var->Persistable() &&
+        in_var_type == framework::proto::VarType::FP32) {
+      if (WeightsShouldNotConvert(in_node)) return;
+      in_var->SetDataType(to_type);
+    } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
+               in_var_type != to_type) {
+      AddCastOp(graph,
+                in_node,
+                op_node,
+                in_var_type,
+                to_type,
+                suffix,
+                block_desc,
+                cast_map);
+    }
+  } else {
+    if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
+        in_var_type != to_type) {
+      AddCastOp(graph,
+                in_node,
+                op_node,
+                in_var_type,
+                to_type,
+                suffix,
+                block_desc,
+                cast_map);
+    }
+  }
+  VLOG(3) << " in_node name " << in_var->Name() << " data_type "
+          << in_var->GetDataType();
+}
 
 void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
@@ -239,6 +377,11 @@ void HandleSpecialOps(framework::OpDesc* op_desc) {
         static_cast<int>(framework::proto::VarType::FP32))
       op_desc->SetAttr("dtype",
                        static_cast<int>(framework::proto::VarType::FP16));
+  } else if (op_desc->Type() == "fill_constant_batch_size_like") {
+    if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
+        static_cast<int>(framework::proto::VarType::FP32))
+      op_desc->SetAttr("dtype",
+                       static_cast<int>(framework::proto::VarType::FP16));
   }
 }
 
@@ -260,26 +403,47 @@ void FixCastAttr(framework::ir::Graph* graph) {
   }
 }
 
-// If op's output var is condition flow op's input, then the op must be fp32
-// precision.
-bool NextOpIncludesConditionFlowOp(framework::ir::Node* cur_op_node) {
-  auto cur_op_outs = cur_op_node->outputs;
-  for (auto out_var : cur_op_outs) {
-    for (auto next_op_node : out_var->outputs) {
-      if (next_op_node->Op()->HasAttr("sub_block")) {
-        return true;
-      }
+void FindVarsInMultiBlock(
+    framework::ProgramDesc* program_desc,
+    std::unordered_map<std::string, framework::proto::VarType::Type>*
+        vars_in_multi_block_map) {
+  std::set<std::string> vars_in_multi_block;
+  std::set<std::string> main_block_var_names_set;
+  for (auto op : program_desc->Block(0).AllOps()) {
+    auto in_names = op->InputArgumentNames();
+    main_block_var_names_set.insert(in_names.begin(), in_names.end());
+  }
+
+  for (size_t i = 1; i < program_desc->Size(); ++i) {
+    std::set<std::string> block_var_names_set;
+    for (auto op : program_desc->Block(i).AllOps()) {
+      auto in_names = op->InputArgumentNames();
+      block_var_names_set.insert(in_names.begin(), in_names.end());
     }
+
+    std::set_intersection(
+        main_block_var_names_set.begin(),
+        main_block_var_names_set.end(),
+        block_var_names_set.begin(),
+        block_var_names_set.end(),
+        std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
+  }
+
+  for (auto name : vars_in_multi_block) {
+    vars_in_multi_block_map->emplace(name, framework::proto::VarType::FP32);
   }
-  return false;
 }
 
-void ConvertTensorDtype(framework::ProgramDesc* program_desc,
-                        framework::ir::Graph* graph,
-                        const std::unordered_set<std::string>& blacklist,
-                        bool keep_io_types,
-                        phi::Backend backend,
-                        phi::DataType tensor_dtype) {
+void ConvertTensorDtype(
+    framework::ProgramDesc* program_desc,
+    framework::ir::Graph* graph,
+    const std::unordered_set<std::string>& blacklist,
+    bool keep_io_types,
+    phi::Backend backend,
+    phi::DataType tensor_dtype,
+    bool is_main_block,
+    std::unordered_map<std::string, framework::proto::VarType::Type>*
+        vars_in_multi_block_map) {
   framework::proto::VarType::Type to_type;
   if (tensor_dtype == phi::DataType::FLOAT16) {
     to_type = framework::proto::VarType::FP16;
@@ -287,25 +451,27 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
     to_type = framework::proto::VarType::BF16;
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
-        "mixed_precision currently not supported dtype %d, we now only support "
+        "mixed_precision currently not supported dtype %d, we now only "
+        "support "
         "fp16 and bf16.",
         static_cast<int>(tensor_dtype)));
   }
 
-  auto weight_name_in_multi_block = GetMultiBlockPersistableNames(program_desc);
+  auto* block_desc =
+      framework::ir::TopologySortOperations(*graph)[0]->Op()->Block();
+
   int num_low_precision = 0;
   int suffix = 0;
-  framework::BlockDesc* block_desc{nullptr};
   std::vector<framework::ir::Node*> output_nodes;
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map;
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
     if (!op_node->IsOp()) continue;
     auto op_type = op_node->Op()->Type();
-    auto phi_op_type = phi::TransToPhiKernelName(op_type);
+    VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
+            << phi::TransToPhiKernelName(op_type);
     // 1. set input dtype.
     if (op_type == "feed") {
-      block_desc = op_node->Op()->Block();
       auto feed_var = op_node->outputs[0]->Var();
       if (!keep_io_types &&
           feed_var->GetDataType() == framework::proto::VarType::FP32) {
@@ -319,71 +485,73 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
       continue;
     }
 
+    else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
+      // A sub_block op's output dtype should be the same as its input dtype
+      // when they share the same name.
+      std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
+      for (auto* in : op_node->inputs) {
+        if (NodeVarHasDtype(in)) {
+          in_name_to_node[in->Name()] = in;
+        }
+      }
+
+      for (auto out : op_node->outputs) {
+        if (NodeVarHasDtype(out)) {
+          if (in_name_to_node.count(out->Name()))
+            out->Var()->SetDataType(
+                in_name_to_node[out->Name()]->Var()->GetDataType());
+        }
+      }
+
+      continue;
+    }
+
     // 2. if op support fp16/bf16 and not in blacklist.
     //      - cast weight to fp16/bf16.
     //      - add cast op if the input dtype is not fp16/bf16.
     //      - set output dtype.
-    else if (blacklist.count(phi_op_type) == 0 &&  // NOLINT
-             !NextOpIncludesConditionFlowOp(op_node)) {
+    else if (blacklist.count(op_type) == 0) {  // NOLINT
       bool support_precision =
-          OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist);
-      VLOG(2) << "op_type " << op_type << ", phi_op_type " << phi_op_type
-              << " support low precision " << support_precision << ", "
+          OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
+      VLOG(2) << "op_type " << op_type << ", phi_op_type "
+              << phi::TransToPhiKernelName(op_type) << " support low precision "
+              << support_precision << ", "
               << reinterpret_cast<void*>(op_node->Op()->Block());
 
-      for (auto in_node : op_node->inputs) {
-        if (weight_name_in_multi_block.count(in_node->Name()))
-          support_precision = false;
-      }
-
       if (support_precision) {
         HandleSpecialOps(op_node->Op());
         ++num_low_precision;
         auto inputs = op_node->inputs;
+        // Process inputs.
         for (auto* in_node : inputs) {
-          if (in_node->IsCtrlVar()) continue;
-          auto* in_var = in_node->Var();
-          if (in_var->Persistable() &&
-              in_var->GetDataType() == framework::proto::VarType::FP32) {
-            if (WeightsShouldNotConvert(in_node)) continue;
-            in_var->SetDataType(to_type);
-          } else if (!in_var->Persistable() &&
-                     IsFloatVarType(in_var->GetDataType()) &&
-                     in_var->GetDataType() != to_type) {
-            AddCastOp(graph,
-                      in_node,
-                      op_node,
-                      in_var->GetDataType(),
-                      to_type,
-                      &suffix,
-                      block_desc,
-                      &cast_map);
-          }
+          ProcessInputNode(true,
+                           graph,
+                           in_node,
+                           op_node,
+                           &suffix,
+                           block_desc,
+                           &cast_map,
+                           to_type,
+                           is_main_block,
+                           vars_in_multi_block_map);
         }
+        // Process outputs.
         for (auto* out_node : op_node->outputs) {
-          if (out_node->IsCtrlVar()) continue;
-          auto* out_var = out_node->Var();
-          if (out_var->GetDataType() == framework::proto::VarType::FP32) {
-            if (OutShouldNotConvert(out_node)) continue;
-            out_var->SetDataType(to_type);
-          }
+          ProcessOutputNode(out_node, to_type);
         }
       } else {
         auto inputs = op_node->inputs;
         for (auto* in_node : inputs) {
-          if (in_node->IsCtrlVar()) continue;
-          auto* in_var = in_node->Var();
-          if (!in_var->Persistable() && IsFloatVarType(in_var->GetDataType()) &&
-              in_var->GetDataType() != framework::proto::VarType::FP32) {
-            AddCastOp(graph,
-                      in_node,
-                      op_node,
-                      in_var->GetDataType(),
-                      framework::proto::VarType::FP32,
-                      &suffix,
-                      block_desc,
-                      &cast_map);
-          }
+          ProcessInputNode(false,
+                           graph,
+                           in_node,
+                           op_node,
+                           &suffix,
+                           block_desc,
+                           &cast_map,
+                           framework::proto::VarType::FP32,
+                           is_main_block,
+                           vars_in_multi_block_map);
         }
       }
     }
@@ -409,8 +577,8 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
     }
   }
 
-  // 4. if output_op's dtype is not compatible to output dtype, then just insert
-  // cast.
+  // 4. if output_op's dtype is not compatible to output dtype, then just
+  // insert cast.
   for (auto* node : output_nodes) {
     if (node->IsCtrlVar()) continue;
     auto var = node->Var();
@@ -438,22 +606,31 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
     }
   }
 
+  if (is_main_block) {
+    for (auto node : graph->Nodes()) {
+      if (vars_in_multi_block_map->count(node->Name())) {
+        vars_in_multi_block_map->at(node->Name()) = node->Var()->GetDataType();
+      }
+    }
+  }
+
   if (num_low_precision)
     LOG(INFO) << "---  detected " << num_low_precision << " low precision ops";
 }
 
 }  // namespace
 
-bool OpSupportPrecision(const std::string& phi_op_type,
+bool OpSupportPrecision(const std::string& op_type,
                         phi::Backend backend,
                         phi::DataType precision,
                         const std::unordered_set<std::string>& blacklist) {
+  auto phi_op_type = phi::TransToPhiKernelName(op_type);
   bool support_precision = false;
-  if (blacklist.count(phi_op_type) == 0) {
+  if (blacklist.count(op_type) == 0) {
     if (backend == phi::Backend::GPU)
-      support_precision = GpuKernelSupportPrecision(phi_op_type, precision);
+      support_precision = GpuKernelSupportPrecision(op_type, precision);
     else
       support_precision =
-          IsKernelSupportPrecision(phi_op_type, backend, precision);
+          PhiKernelSupportPrecision(phi_op_type, backend, precision);
   }
   return support_precision;
 }
@@ -521,102 +698,41 @@ void ConvertToMixedPrecision(const std::string& model_file,
   framework::Scope scope;
   auto program_desc =
       inference::Load(&executor, &scope, model_file, params_file);
-  auto graph = std::unique_ptr<framework::ir::Graph>(
+  auto main_graph = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc));
-  ConvertAllFp64ToFp32(graph.get());
-  ConvertTensorDtype(program_desc.get(),
-                     graph.get(),
-                     black_list,
-                     keep_io_types,
-                     backend,
-                     mixed_precision);
-  FixCastAttr(graph.get());
-
-  framework::ProgramDesc mixed_program_desc;
-  framework::ir::GraphToProgram(*graph, &mixed_program_desc);
-
-  auto parameters = scope.LocalVarNames();
-  std::sort(parameters.begin(), parameters.end());
-
-  auto serialize_params =
-      [](framework::Scope* scope,
-         const std::vector<std::string>& params) -> std::string {
-    std::ostringstream os;
-    phi::CPUContext ctx;
-    for (const auto& param : params) {
-      VLOG(3) << "Serialize param: " << param;
-      PADDLE_ENFORCE_NOT_NULL(
-          scope->FindVar(param),
-          platform::errors::NotFound(
-              "Block should already have a '%s' variable", param));
-      auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
-      framework::SerializeToStream(os, *tensor, ctx);
-    }
-    return os.str();
-  };
-
-  std::unordered_set<std::string> weights_should_be_fp32;
-  for (auto* node : graph->Nodes()) {
-    if (!(node->IsVar() && !node->IsCtrlVar())) continue;
-    if (node->Var()->GetType() ==
-            paddle::framework::proto::VarType::SELECTED_ROWS ||
-        node->Var()->GetType() ==
-            paddle::framework::proto::VarType::LOD_TENSOR ||
-        node->Var()->GetType() ==
-            paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-        node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-        node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB) {
-      if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
-        weights_should_be_fp32.insert(node->Name());
-      }
-    }
-  }
-
-  for (const auto& param_name : parameters) {
-    auto* var = scope.FindLocalVar(param_name);
-    if (var->IsType<framework::LoDTensor>() ||
-        var->IsType<framework::Tensor>()) {
-      auto* t = var->GetMutable<framework::LoDTensor>();
-      framework::Tensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-
-      if (mixed_precision == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
-        mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
-        auto* mixed_data =
-            mixed_tensor.mutable_data<phi::dtype::float16>(platform::CPUPlace());
-        for (int i = 0; i < t->numel(); i++) {
-          mixed_data[i] = static_cast<phi::dtype::float16>(data[i]);
-        }
-        t->clear();
-        paddle::framework::TensorCopySync(mixed_tensor, place, t);
-      } else if (mixed_precision == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
-        mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
-        auto* mixed_data =
-            mixed_tensor.mutable_data<phi::dtype::bfloat16>(platform::CPUPlace());
-        for (int i = 0; i < t->numel(); i++) {
-          mixed_data[i] = static_cast<phi::dtype::bfloat16>(data[i]);
-        }
-        t->clear();
-        paddle::framework::TensorCopySync(mixed_tensor, place, t);
-      }
-    }
+  std::unordered_map<std::string, framework::proto::VarType::Type>
+      vars_in_multi_block_map;
+  FindVarsInMultiBlock(program_desc.get(), &vars_in_multi_block_map);
+
+  for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) {
+    auto graph = main_graph->GetSubGraph(i);
+    VLOG(2) << " -------- handle subgraph " << i << ", has "
+            << graph->Nodes().size() << " nodes";
+
+    program_desc->Block(i).LocalVarNames();
+
+    ConvertAllFp64ToFp32(graph);
+    ConvertTensorDtype(program_desc.get(),
+                       graph,
+                       black_list,
+                       keep_io_types,
+                       backend,
+                       mixed_precision,
+                       i == 0,
+                       &vars_in_multi_block_map);
+    FixCastAttr(graph);
   }
 
-  auto StrToBinary = [](const std::string& path, const std::string& str) {
-    std::ofstream file(path.c_str(), std::ios::binary);
-    file.write(str.c_str(), str.size());
-    file.close();
-  };
-  StrToBinary(mixed_model_file,
-              mixed_program_desc.Proto()->SerializeAsString());
-  StrToBinary(mixed_params_file, serialize_params(&scope, parameters));
+  framework::ProgramDesc mixed_program_desc;
+  framework::ir::GraphToProgram(*main_graph, &mixed_program_desc);
+
+  SaveMixedModel(main_graph.get(),
+                 &scope,
+                 &mixed_program_desc,
+                 mixed_model_file,
+                 mixed_params_file,
+                 mixed_precision);
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index ae90618f5207cd8100bc5460e63e9c796a2dc3ba..742ce01e8458c1a11635260d581c52fb06e5ca9e 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -410,6 +410,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
       pass_builder_->DeletePass(ps);
     }
   }
+
+  for (auto &delete_pass : other.pass_builder()->GetAllDeletedPasses()) {
+    pass_builder_->DeletePass(delete_pass);
+  }
 }
 
 void AnalysisConfig::EnableCUDNN() {
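
Usage note (not part of the patch): a minimal sketch of how the reworked entry point might be driven end to end. The model and output paths and the `softmax` blacklist entry are hypothetical, and the parameter order is inferred from how the arguments are used inside `ConvertToMixedPrecision` above; check it against the declaration in `convert_to_mixed_precision.h` before relying on it.

```cpp
#include <string>
#include <unordered_set>

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

int main() {
  // Hypothetical paths to an exported Paddle inference model.
  const std::string model_file = "serving_model/model.pdmodel";
  const std::string params_file = "serving_model/model.pdiparams";

  // Ops listed here stay in FP32 even when a low-precision kernel exists.
  std::unordered_set<std::string> black_list{"softmax"};

  // With this patch the pass walks every sub-graph (control-flow blocks
  // included), keeps variables shared across blocks at the dtype recorded
  // for block 0, and casts the weights once in SaveMixedModel.
  paddle::inference::analysis::ConvertToMixedPrecision(
      model_file,
      params_file,
      "mixed_model/model.pdmodel",    // hypothetical output program path
      "mixed_model/model.pdiparams",  // hypothetical output params path
      phi::DataType::FLOAT16,         // or phi::DataType::BFLOAT16
      phi::Backend::GPU,              // backend used for kernel lookups
      true,                           // keep_io_types: feed/fetch stay FP32
      black_list);
  return 0;
}
```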
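For the `analysis_config.cc` hunk, a sketch of the behavior it fixes: a pass deleted from one config's pass builder should stay deleted after the config is copied. `pass_builder()->DeletePass()` is existing API; `GetAllDeletedPasses()` is the accessor this patch replays in the copy constructor, and the pass name below is only an example.

```cpp
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

void CopyKeepsDeletedPasses() {
  paddle::AnalysisConfig config;
  config.pass_builder()->DeletePass("fc_fuse_pass");  // example pass name

  // Previously the copy rebuilt the default pass list, silently resurrecting
  // "fc_fuse_pass"; with this patch the copy constructor replays every entry
  // from GetAllDeletedPasses() on the new pass builder.
  paddle::AnalysisConfig copy(config);
}
```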