diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 588a304108f78721d0f3a8d2f3ebcbf01b473618..016df40c86a2da4d7e3ed064abcac35c82198374 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -36,6 +36,7 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/api/paddle_analysis_config.h" #include "paddle/fluid/platform/variant.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace inference { @@ -328,6 +329,9 @@ struct Argument { DECL_ARGUMENT_FIELD(use_npu, UseNpu, bool); DECL_ARGUMENT_FIELD(npu_device_id, NPUDeviceId, int); + // mixed precision related + DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int); + private: std::unordered_set valid_fields_; }; diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 49878884ac6276034f1c77eb0c144ed1fc54710f..4aeaefa3c49c3ff5ed715803b879123f799fb2b9 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -86,6 +86,8 @@ void IRPassManager::CreatePasses(Argument *argument, argument->tensorrt_tuned_dynamic_shape(); pass->Set("with_dynamic_shape", new bool(with_dynamic_shape)); + pass->Set("model_precision", new int(argument->model_precision())); + if (pass_name == "graph_viz_pass") { std::string optim_cache_dir = argument->optim_cache_dir(); std::string dot_file_path; diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index 17bb8b6c62ab7b33c9cf7ac63b4cc3c6a97d1f36..126e2500c4890007cd8f3d579f522df5496b052b 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -10,6 +10,10 @@ cc_library( memory_optim_pass SRCS memory_optimize_pass.cc DEPS analysis_pass zero_copy_tensor) +cc_library( + convert_to_mixed_precision + SRCS convert_to_mixed_precision.cc + DEPS analysis_pass ir_graph_build_pass) cc_library( ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc @@ -46,6 +50,7 @@ cc_library( ir_params_sync_among_devices_pass adjust_cudnn_workspace_size_pass memory_optim_pass + convert_to_mixed_precision inference_op_replace_pass ir_graph_to_program_pass ir_graph_clean_pass) diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b6651678f85e0c0e85c8fad6bc965bfbbc8c782 --- /dev/null +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -0,0 +1,452 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
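+//
+// Overview: ConvertToMixedPrecision() loads an fp32 inference model into a
+// Scope, builds an ir::Graph from it, and lets ConvertTensorDtype() walk the
+// ops in topological order. Ops whose phi kernels support the target
+// precision on the chosen backend (and that are not blacklisted) have their
+// fp32 inputs/outputs switched to fp16/bf16, with cast ops inserted at every
+// precision boundary; feed/fetch boundaries stay fp32 when keep_io_types is
+// true. Persistable fp32 weights are then converted element-wise to the low
+// precision, except the batch_norm Bias/Mean/Scale/Variance tensors, and the
+// rewritten program and parameters are serialized to mixed_model_file and
+// mixed_params_file.
+//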
+ +#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" + +#include + +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/inference/io.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/core/tensor_meta.h" + +using namespace paddle::framework; // NOLINT + +namespace paddle { +namespace inference { +namespace analysis { + +namespace { + +bool IsKernelSupportPrecision( + const std::string& op_type, + phi::Backend backend, + phi::DataType data_type, + phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { + auto kernels = phi::KernelFactory::Instance().kernels(); + if (kernels.find(op_type) == kernels.end()) { + return false; + } + phi::KernelKey kernel_key(backend, layout, data_type); + return phi::KernelFactory::Instance().HasKernel(op_type, kernel_key); +} + +bool GpuKernelSupportPrecision( + const std::string& op_type, + phi::DataType data_type, + phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { + bool res = + IsKernelSupportPrecision(op_type, phi::Backend::GPU, data_type, layout); + res |= IsKernelSupportPrecision( + op_type, phi::Backend::GPUDNN, data_type, layout); + return res; +} + +// Just process special cases. +bool OutShouldNotConvert(ir::Node* var_node) { + auto op_node = var_node->inputs[0]; + auto* op_desc = op_node->Op(); + + // batch_norm's input and output (variance and mean) are the same. + if (op_desc->Type() == "batch_norm") { + auto vecs = op_desc->Output("MeanOut"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + vecs = op_desc->Output("VarianceOut"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + vecs = op_desc->Output("SavedMean"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + vecs = op_desc->Output("SavedVariance"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + } + + return false; +} + +// Just process special cases for weights conversion. +bool WeightsShouldNotConvert(ir::Node* var_node) { + auto op_nodes = var_node->outputs; + for (auto* op_node : op_nodes) { + auto* op_desc = op_node->Op(); + // batch_norm op's bias, mean, scale and variance just be float32, so we can + // not convert the dtype. 
+ if (op_desc->Type() == "batch_norm") { + auto vecs = op_desc->Input("Bias"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + vecs = op_desc->Input("Mean"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + vecs = op_desc->Input("Scale"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + vecs = op_desc->Input("Variance"); + if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) { + return true; + } + } + } + + return false; +} + +void ConvertTensorDtype(framework::ir::Graph* graph, + const std::unordered_set& blacklist, + bool keep_io_types, + phi::Backend backend, + phi::DataType tensor_dtype) { + framework::proto::VarType::Type to_type; + if (tensor_dtype == phi::DataType::FLOAT16) { + to_type = framework::proto::VarType::FP16; + } else if (tensor_dtype == phi::DataType::BFLOAT16) { + to_type = framework::proto::VarType::BF16; + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "mixed_precision currently not supported dtype %d, we now only support " + "fp16 and bf16.", + static_cast(tensor_dtype))); + } + + int num_low_precision = 0; + int suffix = 0; + framework::BlockDesc* block_desc{nullptr}; + std::vector output_nodes; + std::unordered_map cast_map; + + for (auto* op_node : framework::ir::TopologySortOperations(*graph)) { + if (!op_node->IsOp()) continue; + auto op_type = op_node->Op()->Type(); + auto phi_op_type = phi::TransToPhiKernelName(op_type); + // LOG(INFO) << "process op " << op_type << ", corresponding phi type is " + // << phi_op_type; + // 1. set input dtype. + if (op_type == "feed") { + block_desc = op_node->Op()->Block(); + auto feed_var = op_node->outputs[0]->Var(); + if (!keep_io_types && + feed_var->GetDataType() == framework::proto::VarType::FP32) { + feed_var->SetDataType(to_type); + } + } else if (op_type == "fetch") { + auto* fetch_var = op_node->inputs[0]; + output_nodes.push_back(fetch_var); + continue; + } + + // 2. if op support fp16/bf16 and not in blacklist. + // - cast weight to fp16/bf16. + // - add cast op if the input dtype is not fp16/bf16. + // - set output dtype. 
+ else if (blacklist.count(phi_op_type) == 0) { // NOLINT + bool support_precision = + OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist); + VLOG(2) << "phi_op_type " << phi_op_type << " support low precision " + << support_precision; + if (support_precision) { + ++num_low_precision; + auto inputs = op_node->inputs; + for (auto* in_node : inputs) { + auto* in_var = in_node->Var(); + if (in_var->Persistable() && + in_var->GetDataType() == framework::proto::VarType::FP32) { + if (WeightsShouldNotConvert(in_node)) continue; + in_var->SetDataType(to_type); + } else if (!in_var->Persistable() && + in_var->GetDataType() != to_type) { + AddCastOp(graph, + in_node, + op_node, + in_var->GetDataType(), + to_type, + &suffix, + block_desc, + &cast_map); + } + } + for (auto* out_node : op_node->outputs) { + auto* out_var = out_node->Var(); + if (out_var->GetDataType() == framework::proto::VarType::FP32) { + if (OutShouldNotConvert(out_node)) continue; + out_var->SetDataType(to_type); + } + } + } else { + auto inputs = op_node->inputs; + for (auto* in_node : inputs) { + auto* in_var = in_node->Var(); + if (!in_var->Persistable() && + in_var->GetDataType() != framework::proto::VarType::FP32) { + AddCastOp(graph, + in_node, + op_node, + in_var->GetDataType(), + framework::proto::VarType::FP32, + &suffix, + block_desc, + &cast_map); + } + } + } + } + + // 3. The op does not support fp16/bf16, or it is in the blacklist. + // - add cast op if the input dtype is not fp32. + else { // NOLINT + // The trt pass should explicitly add a cast op if the input is bf16/tf32, etc. + if (op_node->Name() == "tensorrt_engine") continue; + for (auto* in_node : op_node->inputs) { + auto* in_var = in_node->Var(); + if (in_var->GetDataType() == to_type) { + AddCastOp(graph, + in_node, + op_node, + to_type, + framework::proto::VarType::FP32, + &suffix, + block_desc, + &cast_map); + } + } + } + } + + // 4. If a fetched variable's dtype does not match the requested output + // dtype, insert a cast. + for (auto* node : output_nodes) { + auto var = node->Var(); + if (keep_io_types && var->GetDataType() == to_type) { + // fp16/bf16 -> fp32.
+ AddCastOp(graph, + node, + node->outputs[0], + to_type, + framework::proto::VarType::FP32, + &suffix, + block_desc, + &cast_map); + } else if (!keep_io_types && + var->GetDataType() == framework::proto::VarType::FP32) { + // fp32 -> fp16/bf16 + AddCastOp(graph, + node, + node->outputs[0], + framework::proto::VarType::FP32, + to_type, + &suffix, + block_desc, + &cast_map); + } + } + + if (num_low_precision) + LOG(INFO) << "--- detected " << num_low_precision << " low precision ops"; +} +} // namespace + +bool OpSupportPrecision(const std::string& phi_op_type, + phi::Backend backend, + phi::DataType precision, + const std::unordered_set& blacklist) { + bool support_precision = false; + if (blacklist.count(phi_op_type) == 0) { + if (backend == phi::Backend::GPU) + support_precision = GpuKernelSupportPrecision(phi_op_type, precision); + else + support_precision = + IsKernelSupportPrecision(phi_op_type, backend, precision); + } + return support_precision; +} + +void AddCastOp( + framework::ir::Graph* graph, + framework::ir::Node* node, + framework::ir::Node* next_op, + framework::proto::VarType::Type from_type, + framework::proto::VarType::Type to_type, + int* suffix, + framework::BlockDesc* block_desc, + std::unordered_map* map) { + auto update_cast_desc = [&](framework::OpDesc& desc, + const std::string& x_name, + const std::string& out_name, + const int in_dtype, + const int out_dtype) { + desc.SetType("cast"); + desc.SetInput("X", {x_name}); + desc.SetOutput("Out", {out_name}); + desc.SetAttr("in_dtype", in_dtype); + desc.SetAttr("out_dtype", out_dtype); + desc.SetAttr("use_mkldnn", false); + desc.SetAttr("with_quant_attr", false); + desc.Flush(); + }; + + if (map->count(node) == 0) { + // insert cast op before node. + std::string cast_input_name = node->Var()->Name(); + std::string cast_output_name = + node->Var()->Name() + "_cast.tmp_" + std::to_string((*suffix)++); + CHECK_NOTNULL(block_desc); + framework::OpDesc cast_op_desc(block_desc); + update_cast_desc(cast_op_desc, + cast_input_name, + cast_output_name, + static_cast(from_type), + static_cast(to_type)); + auto* cast_op_node = graph->CreateOpNode(&cast_op_desc); + auto* cast_output_vardesc = block_desc->Var(cast_output_name); + cast_output_vardesc->SetPersistable(false); + cast_output_vardesc->SetDataType(to_type); + cast_output_vardesc->SetShape(node->Var()->GetShape()); + auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc); + IR_NODE_LINK_TO(cast_op_node, cast_output_node); + (*map)[node] = cast_output_node; + } + next_op->Op()->RenameInput(node->Name(), map->at(node)->Name()); + IR_NODE_LINK_TO(node, map->at(node)->inputs[0]); + IR_NODE_LINK_TO(map->at(node), next_op); +} + +void ConvertToMixedPrecision(const std::string& model_file, + const std::string& params_file, + const std::string& mixed_model_file, + const std::string& mixed_params_file, + phi::DataType mixed_precision, + phi::Backend backend, + bool keep_io_types, + std::unordered_set black_list) { + paddle::CPUPlace place; + framework::Executor executor(place); + framework::Scope scope; + auto program_desc = + inference::Load(&executor, &scope, model_file, params_file); + auto graph = std::unique_ptr( + new framework::ir::Graph(*program_desc)); + + ConvertTensorDtype( + graph.get(), black_list, keep_io_types, backend, mixed_precision); + + framework::ProgramDesc mixed_program_desc; + framework::ir::GraphToProgram(*graph, &mixed_program_desc); + + auto parameters = scope.LocalVarNames(); + std::sort(parameters.begin(), parameters.end()); + + auto 
serialize_params = + [](framework::Scope* scope, + const std::vector& params) -> std::string { + std::ostringstream os; + platform::CPUDeviceContext ctx; + for (const auto& param : params) { + VLOG(3) << "Serialize param: " << param; + PADDLE_ENFORCE_NOT_NULL( + scope->FindVar(param), + platform::errors::NotFound( + "Block should already have a '%s' variable", param)); + auto* tensor = scope->FindVar(param)->GetMutable(); + framework::SerializeToStream(os, *tensor, ctx); + } + return os.str(); + }; + + std::unordered_set weights_should_be_fp32; + for (auto* node : paddle::framework::ir::TopologySortOperations(*graph)) { + if (!node->IsOp()) continue; + auto* op_desc = node->Op(); + if (op_desc->Type() == "feed" || op_desc->Type() == "fetch") continue; + + if (op_desc->Type() == "batch_norm") { + auto vecs = op_desc->Input("Bias"); + for (auto s : vecs) { + weights_should_be_fp32.insert(s); + } + vecs = op_desc->Input("Mean"); + for (auto s : vecs) { + weights_should_be_fp32.insert(s); + } + vecs = op_desc->Input("Scale"); + for (auto s : vecs) { + weights_should_be_fp32.insert(s); + } + vecs = op_desc->Input("Variance"); + for (auto s : vecs) { + weights_should_be_fp32.insert(s); + } + } + } + + for (const auto& param_name : parameters) { + auto* var = scope.FindLocalVar(param_name); + if (var->IsType() || + var->IsType()) { + auto* t = var->GetMutable(); + framework::Tensor mixed_tensor; + mixed_tensor.Resize(t->dims()); + auto* data = t->mutable_data(platform::CPUPlace()); + + if (mixed_precision == phi::DataType::FLOAT16 && + !weights_should_be_fp32.count(param_name)) { + mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16); + auto* mixed_data = + mixed_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < t->numel(); i++) { + mixed_data[i] = static_cast(data[i]); + } + t->clear(); + paddle::framework::TensorCopySync(mixed_tensor, place, t); + } else if (mixed_precision == phi::DataType::BFLOAT16 && + !weights_should_be_fp32.count(param_name)) { + mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16); + auto* mixed_data = + mixed_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < t->numel(); i++) { + mixed_data[i] = static_cast(data[i]); + } + t->clear(); + paddle::framework::TensorCopySync(mixed_tensor, place, t); + } + } + } + + auto StrToBinary = [](const std::string& path, const std::string& str) { + std::ofstream file(path.c_str(), std::ios::binary); + file.write(str.c_str(), str.size()); + file.close(); + }; + StrToBinary(mixed_model_file, + mixed_program_desc.Proto()->SerializeAsString()); + StrToBinary(mixed_params_file, serialize_params(&scope, parameters)); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h new file mode 100644 index 0000000000000000000000000000000000000000..2a19453b02a01182cf763d934f62c3dda16bce5e --- /dev/null +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/data_type.h" + +namespace paddle { +namespace inference { +namespace analysis { + +bool OpSupportPrecision(const std::string& phi_op_type, + phi::Backend backend, + phi::DataType precision, + const std::unordered_set& blacklist); + +void AddCastOp( + framework::ir::Graph* graph, + framework::ir::Node* node, + framework::ir::Node* next_op, + framework::proto::VarType::Type from_type, + framework::proto::VarType::Type to_type, + int* suffix, + framework::BlockDesc* block_desc, + std::unordered_map* map); + +void ConvertToMixedPrecision(const std::string& model_file, + const std::string& params_file, + const std::string& mixed_model_file, + const std::string& mixed_params_file, + phi::DataType mixed_precision, + phi::Backend backend, + bool keep_io_types = true, + std::unordered_set black_list = {}); + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc index a785aba4bb40bc6483f84b909ac947543e1a1fc2..bc330354e71fcb724251e3a80183f6d683a02590 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc @@ -14,10 +14,16 @@ #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h" +#include + #include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace inference { @@ -106,34 +112,63 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) { if (with_dynamic_shape) { reserve_cpu_weights = true; } - for (auto &var_name : all_vars) { - if (std::count( - repetitive_params.begin(), repetitive_params.end(), var_name)) { - if (!reserve_cpu_weights) { - scope->EraseVars({var_name}); - } - continue; - } - auto *var = scope->FindLocalVar(var_name); - PADDLE_ENFORCE_NOT_NULL( - var, - platform::errors::PreconditionNotMet("The var should not be nullptr")); - if (var->IsType() || - var->IsType()) { - auto *t = var->GetMutable(); - platform::CPUPlace cpu_place; - framework::LoDTensor temp_tensor; - temp_tensor.Resize(t->dims()); - temp_tensor.mutable_data(cpu_place); - - // Copy the parameter data to a tmp tensor. 
- paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor); - // Reallocation the space on GPU - t->clear(); - - // Copy parameter data to newly allocated GPU space. - paddle::framework::TensorCopySync(temp_tensor, place, t); + for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) { + if (!node->IsOp()) continue; + if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch") continue; + for (auto *var_node : node->inputs) { + if (!var_node->Var()->Persistable()) continue; + auto var_name = var_node->Var()->Name(); + if (std::count( + repetitive_params.begin(), repetitive_params.end(), var_name)) { + if (!reserve_cpu_weights) { + scope->EraseVars({var_name}); + } + continue; + } + auto *var = scope->FindLocalVar(var_name); + PADDLE_ENFORCE_NOT_NULL(var, + platform::errors::PreconditionNotMet( + "The var should not be nullptr")); + if (var->IsType() || + var->IsType()) { + auto *t = var->GetMutable(); + auto var_data_type = var_node->Var()->GetDataType(); + VLOG(5) << "var_name is " << var_name << ", data type is " + << var_data_type; + if (var_data_type == paddle::framework::proto::VarType::FP16) { + framework::Tensor half_tensor; + half_tensor.set_type(paddle::experimental::DataType::FLOAT16); + half_tensor.Resize(t->dims()); + auto *half_data = + half_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < t->numel(); i++) { + auto *data = t->mutable_data(platform::CPUPlace()); + half_data[i] = static_cast(data[i]); + } + t->clear(); + paddle::framework::TensorCopySync(half_tensor, place, t); + } else if (var_data_type == paddle::framework::proto::VarType::BF16) { + framework::Tensor bf16_tensor; + bf16_tensor.set_type(paddle::experimental::DataType::BFLOAT16); + bf16_tensor.Resize(t->dims()); + auto *bf16_data = bf16_tensor.mutable_data( + platform::CPUPlace()); + for (int i = 0; i < t->numel(); i++) { + auto *data = t->mutable_data(platform::CPUPlace()); + bf16_data[i] = static_cast(data[i]); + } + t->clear(); + paddle::framework::TensorCopySync(bf16_tensor, place, t); + } else { + platform::CPUPlace cpu_place; + framework::LoDTensor temp_tensor; + temp_tensor.Resize(t->dims()); + paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor); + t->clear(); + paddle::framework::TensorCopySync(temp_tensor, place, t); + } + } } } } diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index cace195640f642baa8ae4fa50a665877178ef409..0d55b9c66416a4ef024e73d7fd32b37c14f0a4d6 100755 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -82,6 +82,7 @@ if(WITH_ONNXRUNTIME) ir_pass_manager op_compatible_info infer_io_utils + model_utils onnxruntime paddle2onnx) else() @@ -90,7 +91,7 @@ else() SRCS analysis_predictor.cc resource_manager.cc infer_context.cc ${mkldnn_quantizer_src} DEPS ${inference_deps} zero_copy_tensor ir_pass_manager op_compatible_info - infer_io_utils) + infer_io_utils model_utils) endif() cc_test( diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7cf49e533f7c5e9cc8e2229c996929fbd643b9f2..7bdd1c957b713dad0d248df76e2cb39f7b8913c1 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -36,12 +36,15 @@ #include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/framework/version.h" #include "paddle/fluid/inference/analysis/helper.h" +#include 
"paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" #include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/infer_context.h" +#include "paddle/fluid/inference/api/paddle_analysis_config.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h" #include "paddle/fluid/inference/utils/io_utils.h" +#include "paddle/fluid/inference/utils/model_utils.h" #include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/cpu_helper.h" @@ -50,6 +53,8 @@ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/utils/string/split.h" @@ -102,6 +107,43 @@ bool IsPersistable(const framework::VarDesc *var) { } return false; } + +phi::DataType ConvertPrecision(AnalysisConfig::Precision precision) { + switch (precision) { + case AnalysisConfig::Precision::kFloat32: + return phi::DataType::FLOAT32; + case AnalysisConfig::Precision::kHalf: + return phi::DataType::FLOAT16; + case AnalysisConfig::Precision::kBf16: + return phi::DataType::BFLOAT16; + case AnalysisConfig::Precision::kInt8: + return phi::DataType::INT8; + default: + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Paddle Inference not support precision. We now only support " + "Float32, Half, Bfloat16 and Int8")); + return phi::DataType::FLOAT32; + } +} + +phi::Backend ConvertBackend(AnalysisConfig::Backend backend) { + switch (backend) { + case AnalysisConfig::Backend::kGPU: + // NOTE: phi also support phi::Backend::GPUDNN. + return phi::Backend::GPU; + case AnalysisConfig::Backend::kNPU: + return phi::Backend::NPU; + case AnalysisConfig::Backend::kXPU: + return phi::Backend::XPU; + case AnalysisConfig::Backend::kCPU: + return phi::Backend::CPU; + default: + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Paddle Inference not support backend, we now only support GPU, XPU, " + "NPU and CPU.")); + return phi::Backend::CPU; + } +} } // namespace bool PaddleTensorToLoDTensor(const PaddleTensor &pt, @@ -476,6 +518,8 @@ bool AnalysisPredictor::PrepareProgram( // if enable_ir_optim_ is false, // the analysis pass(op fuse, graph analysis, trt subgraph, mkldnn etc) will // not be executed. + model_precision_ = + paddle::inference::GetModelPrecision(*inference_program_); OptimizeInferenceProgram(); } else { // If the program is passed from external, no need to optimize it, this @@ -1129,6 +1173,40 @@ void AnalysisPredictor::PrepareArgument() { #endif auto passes = config_.pass_builder()->AllPasses(); + if (model_precision_ != phi::DataType::FLOAT32) { + LOG(INFO) << "Model is mixed precision type with " << model_precision_ + << ", we will use a new PassStrategy. 
Note that only the GPU " + "backend is supported for now."; + passes.clear(); + if (config_.tensorrt_engine_enabled()) { + for (const auto &pass : kTrtLowerPrecisionPasses) { + passes.push_back(pass); + } + } else if (config_.use_gpu()) { + for (const auto &pass : kGpuLowerPrecisionPasses) { + passes.push_back(pass); + } + } + + const auto &deleted_passes = config_.pass_builder()->GetAllDeletedPasses(); + for (const auto &it : deleted_passes) { + auto iterator = std::find(passes.begin(), passes.end(), it); + if (iterator != passes.end()) { + passes.erase(iterator); + } + } + + if (config_.ir_debug_) { + auto it = std::begin(passes); + while (it != std::end(passes)) { + if (*it != "graph_viz_pass") { + it = passes.insert(it + 1, "graph_viz_pass"); + } else { + ++it; + } + } + } + } if (!config_.ir_optim()) { passes.clear(); LOG(INFO) << "ir_optim is turned off, no IR pass will be executed"; @@ -1137,6 +1215,8 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetIrAnalysisPasses(passes); argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses()); argument_.SetScopeNotOwned(scope_.get()); + + argument_.SetModelPrecision(static_cast(model_precision_)); } // NOTE All the members in AnalysisConfig should be copied to Argument. @@ -2112,6 +2192,26 @@ std::string UpdateDllFlag(const char *name, const char *value) { return paddle::UpdateDllFlag(name, value); } +void ConvertToMixedPrecision(const std::string &model_file, + const std::string ¶ms_file, + const std::string &mixed_model_file, + const std::string &mixed_params_file, + PrecisionType mixed_precision, + BackendType backend, + bool keep_io_types, + std::unordered_set black_list) { + auto phi_backend = paddle::ConvertBackend(backend); + auto phi_precision = paddle::ConvertPrecision(mixed_precision); + paddle::inference::analysis::ConvertToMixedPrecision(model_file, + params_file, + mixed_model_file, + mixed_params_file, + phi_precision, + phi_backend, + keep_io_types, + black_list); +} + } // namespace paddle_infer namespace paddle_infer { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 9ebd2a3ab0fa5692397efabc1dcaf3d0b8259fd4..0835f712b6e4607d780c0c18f168d12f8e272f8e 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -18,6 +18,7 @@ #include #include #include +#include "paddle/phi/common/data_type.h" #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #endif @@ -478,6 +479,8 @@ class AnalysisPredictor : public PaddlePredictor { std::vector fetches_; std::map idx2fetches_; + phi::DataType model_precision_{phi::DataType::FLOAT32}; + #if PADDLE_WITH_MKLDNN // Helper class to perform quantization class MkldnnQuantizer; diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 34e18a407eeaf65862251af5148294ca7f8c0105..6de23e930836aff03c544a08ab123724f0277b91 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -167,6 +167,14 @@ struct PD_INFER_DECL AnalysisConfig { kFloat32 = 0, ///< fp32 kInt8, ///< int8 kHalf, ///< fp16 + kBf16, ///< bf16 + }; + + enum class Backend { + kCPU = 0, + kGPU, + kXPU, + kNPU, }; /// diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 
3111db026c4e6a906db1430594bff4005a293b69..c3ccb58b8031ce19d04fe01dc1893e56573215fe 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -25,6 +25,7 @@ limitations under the License. */ #include #include #include +#include #include #include @@ -46,6 +47,7 @@ namespace paddle_infer { using PrecisionType = paddle::AnalysisConfig::Precision; using Config = paddle::AnalysisConfig; using DistConfig = paddle::DistConfig; +using BackendType = paddle::AnalysisConfig::Backend; /// /// \class Predictor @@ -183,6 +185,16 @@ PD_INFER_DECL std::tuple GetTrtCompileVersion(); PD_INFER_DECL std::tuple GetTrtRuntimeVersion(); PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value); +PD_INFER_DECL void ConvertToMixedPrecision( + const std::string& model_file, + const std::string& params_file, + const std::string& mixed_model_file, + const std::string& mixed_params_file, + PrecisionType mixed_precision, + BackendType backend, + bool keep_io_types = true, + std::unordered_set black_list = {}); + namespace services { /// /// \class PredictorPool diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 954a6898781c18719890f8a253430d153b1baf69..73c216290dd88e3aa8d59d4ab53172e7e8eff80a 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -52,6 +52,7 @@ std::string PaddlePassBuilder::DebugString() { } void PaddlePassBuilder::DeletePass(const std::string &pass_type) { + deleted_passes_.insert(pass_type); auto it = std::begin(passes_); while (it != std::end(passes_)) { if (*it == pass_type) { @@ -149,6 +150,19 @@ const std::vector kLiteSubgraphPasses({ #endif }); +// TODO(inference): Most of the existing pass fusion operators do not +// support fp16/bf16 precision, temporarily use low precision pass to prevent +// running errors. After fusion operator supports low precision, delete this. 
+const std::vector kGpuLowerPrecisionPasses{ + // "conv_bn_fuse_pass", + // "conv_eltwiseadd_bn_fuse_pass", +}; +const std::vector kTrtLowerPrecisionPasses{ + // "conv_bn_fuse_pass", + // "conv_eltwiseadd_bn_fuse_pass", + "tensorrt_subgraph_pass", +}; + GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { passes_.assign({ // "identity_scale_op_clean_pass", // diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 2b6c189cffcf270edb7396900061233a3eff195c..cd97382785395f741df35b787cb5fbfe14e2d182 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "paddle_infer_declare.h" // NOLINT @@ -106,6 +107,10 @@ class PD_INFER_DECL PaddlePassBuilder { return passes; } + const std::unordered_set &GetAllDeletedPasses() const { + return deleted_passes_; + } + protected: /// \cond Protected std::vector analysis_passes_{ @@ -116,6 +121,7 @@ class PD_INFER_DECL PaddlePassBuilder { "adjust_cudnn_workspace_size_pass", "inference_op_replace_pass"}}; std::vector passes_; + std::unordered_set deleted_passes_; /// \endcond }; @@ -177,6 +183,8 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder { bool use_ipu_{false}; bool use_mkldnn_{false}; bool use_custom_device_{false}; + + bool use_gpu_low_precision_{false}; /// \endcond }; @@ -328,4 +336,10 @@ PD_INFER_DECL extern const std::vector kDlnneSubgraphPasses; /// \brief List of lite subgraph passes. PD_INFER_DECL extern const std::vector kLiteSubgraphPasses; +/// \brief TODO(inference): Most of the existing pass fusion operators do not +/// support fp16/bf16 precision, temporarily use low precision pass to prevent +/// running errors. After fusion operator supports low precision, delete this. +PD_INFER_DECL extern const std::vector kGpuLowerPrecisionPasses; +PD_INFER_DECL extern const std::vector kTrtLowerPrecisionPasses; + } // namespace paddle diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt index a32a61842a5ec5c8c0278bb7455a85bc25daf163..9ab07633e0fe05595417c3399fe41cbada13c140 100644 --- a/paddle/fluid/inference/utils/CMakeLists.txt +++ b/paddle/fluid/inference/utils/CMakeLists.txt @@ -10,6 +10,10 @@ cc_library( infer_io_utils SRCS io_utils.cc DEPS paddle_inference_api lod_tensor shape_range_info_proto) +cc_library( + model_utils + SRCS model_utils.cc + DEPS proto_desc enforce) cc_test( infer_io_utils_tester SRCS io_utils_tester.cc diff --git a/paddle/fluid/inference/utils/model_utils.cc b/paddle/fluid/inference/utils/model_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..27bc8b35306e1bff7f912525073464895080dc45 --- /dev/null +++ b/paddle/fluid/inference/utils/model_utils.cc @@ -0,0 +1,74 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/inference/utils/model_utils.h" +#include +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/var_type_inference.h" +#include "paddle/phi/common/data_type.h" + +namespace paddle { +namespace inference { + +using paddle::framework::proto::VarType; + +// Get all model's weights and return the data_type, e.g., fp16/bf16 or fp32. +phi::DataType GetModelPrecision(const framework::ProgramDesc& program) { + std::set model_types{ + VarType::FP32, + VarType::FP16, + VarType::BF16, + }; + + phi::DataType ret = phi::DataType::FLOAT32; + size_t block_size = program.Size(); + + for (size_t i = 0; i < block_size; ++i) { + const auto& block = program.Block(i); + for (auto* var : block.AllVars()) { + if (!(var->GetType() == VarType::LOD_TENSOR || + var->GetType() == VarType::LOD_TENSOR_ARRAY)) + continue; + + if (!var->Persistable()) continue; + auto t = var->GetDataType(); + if (!model_types.count(t)) continue; + + if (t == VarType::FP16) { + if (ret != phi::DataType::FLOAT32 && ret != phi::DataType::FLOAT16) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "The model's weights already has been set %s type, but also has " + "%s type, which is an error, please check the model.", + ret, + phi::DataType::FLOAT16)); + } + ret = phi::DataType::FLOAT16; + } else if (t == VarType::BF16) { + if (ret != phi::DataType::FLOAT32 && ret != phi::DataType::BFLOAT16) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "The model's weights already has been set %s type, but also has " + "%s type, which is an error, please check the model.", + ret, + phi::DataType::BFLOAT16)); + } + ret = phi::DataType::BFLOAT16; + } + } + } + + return ret; +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/utils/model_utils.h b/paddle/fluid/inference/utils/model_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..54144b212034cc48a847e5a269f374b36e794911 --- /dev/null +++ b/paddle/fluid/inference/utils/model_utils.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/phi/common/data_type.h" + +namespace paddle { +namespace inference { + +// Get all model's weights and return the data_type, e.g., fp16/bf16 or fp32. +phi::DataType GetModelPrecision(const framework::ProgramDesc& program); + +} // namespace inference +} // namespace paddle