未验证 提交 0972d6ac 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Paddle Inference] improve convert_to_mixed_precision (#47333)

上级 5429d145
......@@ -42,13 +42,13 @@
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_meta.h"
using namespace paddle::framework; // NOLINT
namespace paddle {
namespace inference {
namespace analysis {
namespace {
using VarType = framework::proto::VarType;
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
......@@ -73,13 +73,14 @@ bool GpuKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, data_type, layout);
if (!res) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
kern_pair.first.data_type_ == VarType::FP16) {
res = true;
break;
}
}
}
......@@ -88,6 +89,8 @@ bool GpuKernelSupportPrecision(
}
class ConvertToMixedPrecisionPass {
using BlockID = size_t;
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file,
......@@ -97,7 +100,7 @@ class ConvertToMixedPrecisionPass {
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
std::unordered_set<std::string> black_list)
const std::unordered_set<std::string>& black_list)
: model_file_(model_file),
params_file_(params_file),
mixed_model_file_(mixed_model_file),
......@@ -107,45 +110,40 @@ class ConvertToMixedPrecisionPass {
keep_io_types_(keep_io_types),
black_list_(black_list),
place_(paddle::CPUPlace()),
executor_(place_) {
black_list_.insert("assign");
black_list_.insert("fill_constant");
black_list_.insert("assign_value");
black_list_.insert("eye");
black_list_.insert("fill_any_like");
black_list_.insert("fill_constant_batch_size_like");
}
executor_(place_) {}
void Run();
private:
void LoadAndPrepare();
inline bool NodeVarHasDtype(framework::ir::Node* node);
inline bool VarNodeHasDtype(framework::ir::Node* node);
void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
void FixCastAttr(framework::ir::Graph* graph);
void SaveMixedModel();
void ConvertTensorDtype(int block_idx);
void ConvertTensorDtype(BlockID block_idx);
void ProcessInputNode(bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
framework::ir::Node* in_node,
framework::ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx);
VarType::Type to_type,
BlockID block_idx);
void ProcessOutputNode(int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type);
inline bool IsFloatVarType(framework::proto::VarType::Type type);
void ProcessOutputNode(BlockID block_idx,
framework::ir::Node* var_node,
VarType::Type to_type);
inline bool IsFloatVarType(VarType::Type type);
bool OutShouldNotConvert(ir::Node* var_node);
bool OutShouldNotConvert(framework::ir::Node* var_node);
// Just process special cases for weights conversion.
bool WeightsShouldNotConvert(ir::Node* var_node);
bool WeightsShouldNotConvert(framework::ir::Node* var_node);
// To support multi block, we need to consider a lot of special cases.
// Return Node* which first appers in block.
framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
framework::ir::Node* GetRealVarNode(BlockID block_idx,
framework::ir::Node* node);
void FindVarsInMultiBlock();
inline bool VarIsMultiPrecisionOpsOut(int block_idx,
inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
framework::ir::Node* op_node);
private:
......@@ -167,11 +165,10 @@ class ConvertToMixedPrecisionPass {
framework::Scope scope_;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>
vars_in_multi_block_map_;
std::vector<std::unordered_map<std::string, std::vector<std::string>>>
vars_appear_multi_in_one_block_;
std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
vars_in_multi_block_with_pair_;
std::unordered_map<std::string, std::vector<std::string>>
vars_in_multi_block_with_ops_;
int suffix_{0};
std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
......@@ -179,91 +176,84 @@ class ConvertToMixedPrecisionPass {
std::vector<framework::ir::Graph*> graphes_;
};
framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
int block_idx, framework::ir::Node* node) {
if (vars_in_multi_block_map_.count(node->Name())) {
int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
if (block_idx != var_origin_block_id) {
auto graph = graphes_[var_origin_block_id];
for (auto nd : graph->Nodes()) {
if (nd->Name() == node->Name()) {
return nd;
framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
BlockID block_idx, framework::ir::Node* var_node) {
CHECK_EQ(var_node->IsVar(), true);
if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
auto origin_blockId =
vars_in_multi_block_with_pair_.at(var_node->Name()).second;
if (block_idx != origin_blockId) {
auto* graph = graphes_[origin_blockId];
for (auto* node : graph->Nodes()) {
if (node->Name() == var_node->Name()) {
return node;
}
}
}
}
return node;
return var_node;
}
inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
framework::ir::Node* node) {
if (node->IsVar() &&
(node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
return true;
}
return false;
inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
framework::ir::Node* var_node) {
CHECK_EQ(var_node->IsVar(), true);
auto type = var_node->Var()->GetType();
return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
(type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
(type == VarType::VOCAB);
}
// op1(fp32) -> var1, op2(fp16) -> var1
// if and only if op1 and op2 both support fp16, we convert op1 and op2's
// precision.
inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
int block_idx, framework::ir::Node* op_node) {
BlockID block_idx, framework::ir::Node* op_node) {
CHECK_EQ(op_node->IsOp(), true);
bool ret{false};
for (auto* out : op_node->outputs) {
auto* real_node = GetRealNode(block_idx, out);
if (!real_node->Var()->Persistable() &&
vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
for (auto op_type :
vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
if (OpSupportPrecision(
for (auto* var_node : op_node->outputs) {
if (!var_node->IsVar()) continue;
auto* real_var_node = GetRealVarNode(block_idx, var_node);
if (!real_var_node->Var()->Persistable() &&
vars_in_multi_block_with_ops_.count(var_node->Name())) {
for (const auto& op_type :
vars_in_multi_block_with_ops_.at(var_node->Name())) {
if (!OpSupportPrecision(
op_type, backend_, mixed_precision_, black_list_)) {
ret = true;
VLOG(2) << out->Name()
VLOG(2) << var_node->Name()
<< " is multi precision op's out, so we skip convert to fp16";
break;
return true;
}
}
}
if (ret) break;
}
return ret;
return false;
}
void ConvertToMixedPrecisionPass::ProcessInputNode(
bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
framework::ir::Node* in_node,
framework::ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx) {
auto* real_node = GetRealNode(block_idx, in_node);
if (!NodeVarHasDtype(real_node)) return;
auto graph = graphes_[block_idx];
VarType::Type to_type,
BlockID block_idx) {
if (!in_node->IsVar()) return;
auto* real_node = GetRealVarNode(block_idx, in_node);
if (!VarNodeHasDtype(real_node)) return;
auto* graph = graphes_[block_idx];
bool is_main_block = block_idx == 0;
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType();
auto prev_type = in_var_type;
bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
bool is_in_multi_block = vars_in_multi_block_with_pair_.count(in_var->Name());
if (!is_main_block && is_in_multi_block) {
in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
}
if (support_precision) {
if (in_var->Persistable() &&
in_var_type == framework::proto::VarType::FP32) {
if (in_var->Persistable() && in_var_type == VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) return;
in_var->SetDataType(to_type);
in_var_type = to_type;
......@@ -300,14 +290,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
}
void ConvertToMixedPrecisionPass::ProcessOutputNode(
int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type) {
auto* real_node = GetRealNode(block_idx, var_node);
if (!NodeVarHasDtype(real_node)) return;
BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
if (!var_node->IsVar()) return;
auto* real_node = GetRealVarNode(block_idx, var_node);
if (!VarNodeHasDtype(real_node)) return;
auto* out_var = real_node->Var();
auto prev_type = out_var->GetDataType();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (out_var->GetDataType() == VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
}
......@@ -316,7 +305,8 @@ void ConvertToMixedPrecisionPass::ProcessOutputNode(
}
// Just process special cases.
bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
framework::ir::Node* var_node) {
auto op_node = var_node->inputs[0];
auto* op_desc = op_node->Op();
......@@ -343,7 +333,8 @@ bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
return false;
}
bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
framework::ir::Node* var_node) {
auto op_nodes = var_node->outputs;
for (auto* op_node : op_nodes) {
auto* op_desc = op_node->Op();
......@@ -391,13 +382,10 @@ bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
return false;
}
inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 ||
type == framework::proto::VarType::BF16)
return true;
return false;
inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
return (type == VarType::FP16) || (type == VarType::FP32) ||
(type == VarType::BF16);
}
void ConvertToMixedPrecisionPass::LoadAndPrepare() {
......@@ -405,6 +393,10 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
inference::Load(&executor_, &scope_, model_file_, params_file_);
main_graph_ = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc_));
for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
auto* graph = main_graph_->GetSubGraph(i);
graphes_.push_back(graph);
}
// Remove all control var
IrInferCleanGraphPass pass;
......@@ -412,41 +404,45 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
arg.SetMainGraphNotOwned(main_graph_.get());
pass.Run(&arg);
vars_appear_multi_in_one_block_.resize(program_desc_->Size());
FindVarsInMultiBlock();
}
void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
for (size_t i = 0; i < program_desc_->Size(); ++i) {
for (auto op : program_desc_->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set[i].insert(in_names.begin(), in_names.end());
auto out_names = op->OutputArgumentNames();
std::unordered_set<std::string> all_var_names_set;
std::vector<std::unordered_set<std::string>> block_var_names_set(
program_desc_->Size());
for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
for (auto* op : program_desc_->Block(idx).AllOps()) {
const auto& in_names = op->InputArgumentNames();
block_var_names_set[idx].insert(in_names.begin(), in_names.end());
const auto& out_names = op->OutputArgumentNames();
block_var_names_set[idx].insert(out_names.begin(), out_names.end());
if (op->HasAttr("sub_block") == false) {
for (auto& n : out_names) {
if (block_var_names_set[i].count(n)) {
vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
for (const auto& name : out_names) {
if (all_var_names_set.count(name)) {
vars_in_multi_block_with_ops_[name].push_back(op->Type());
}
}
}
block_var_names_set[i].insert(out_names.begin(), out_names.end());
all_var_names_set.insert(block_var_names_set[idx].begin(),
block_var_names_set[idx].end());
}
}
for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
std::set<std::string> vars_in_multi_block;
std::set_intersection(
block_var_names_set[i].begin(),
block_var_names_set[i].end(),
block_var_names_set[j].begin(),
block_var_names_set[j].end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
CHECK_GT(program_desc_->Size(), 0U);
for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
std::vector<std::string> vars_in_multi_block;
std::set_intersection(block_var_names_set[idx].begin(),
block_var_names_set[idx].end(),
block_var_names_set[jdx].begin(),
block_var_names_set[jdx].end(),
std::back_inserter(vars_in_multi_block));
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map_.emplace(
name, std::make_pair(framework::proto::VarType::FP32, i));
for (const auto& name : vars_in_multi_block) {
vars_in_multi_block_with_pair_.emplace(
name, std::make_pair(VarType::FP32, idx));
}
}
}
......@@ -462,41 +458,34 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
if (op_type == "fill_constant") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
static_cast<int>(VarType::FP64))
op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
} else if (op_type == "assign_value") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
static_cast<int>(VarType::FP64))
op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
} else if (op_type == "eye") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
static_cast<int>(VarType::FP64))
op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
} else if (op_type == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
static_cast<int>(VarType::FP64))
op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
} else if (op_type == "cast") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"in_dtype", static_cast<int>(framework::proto::VarType::FP32));
static_cast<int>(VarType::FP64))
op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"out_dtype", static_cast<int>(framework::proto::VarType::FP32));
static_cast<int>(VarType::FP64))
op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
}
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
auto* in_var = in_node->Var();
if (!in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP64) {
in_var->SetDataType(framework::proto::VarType::FP32);
if (!in_var->Persistable() && in_var->GetDataType() == VarType::FP64) {
in_var->SetDataType(VarType::FP32);
}
}
}
......@@ -505,9 +494,8 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
void ConvertToMixedPrecisionPass::Run() {
LoadAndPrepare();
for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
auto graph = main_graph_->GetSubGraph(i);
graphes_.push_back(graph);
for (size_t i = 0; i < graphes_.size(); ++i) {
auto* graph = graphes_[i];
VLOG(2) << " -------- handle subgraph " << i << ", has "
<< graph->Nodes().size() << " nodes --------";
......@@ -518,19 +506,19 @@ void ConvertToMixedPrecisionPass::Run() {
// A trick
PatchForStrangeOp();
CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
}
SaveMixedModel();
}
void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
auto graph = graphes_[block_idx];
framework::proto::VarType::Type to_type;
void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
auto* graph = graphes_[block_idx];
VarType::Type to_type;
if (mixed_precision_ == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16;
to_type = VarType::FP16;
} else if (mixed_precision_ == phi::DataType::BFLOAT16) {
to_type = framework::proto::VarType::BF16;
to_type = VarType::BF16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only "
......@@ -551,8 +539,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
// 1. set input dtype.
if (op_type == "feed") {
auto feed_var = op_node->outputs[0]->Var();
if (!keep_io_types_ &&
feed_var->GetDataType() == framework::proto::VarType::FP32) {
if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
feed_var->SetDataType(to_type);
}
} else if (op_type == "fetch") {
......@@ -568,15 +555,17 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
// same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) {
auto* real_node = GetRealNode(block_idx, in);
if (NodeVarHasDtype(real_node)) {
if (!in->IsVar()) continue;
auto* real_node = GetRealVarNode(block_idx, in);
if (VarNodeHasDtype(real_node)) {
in_name_to_node[in->Name()] = in;
}
}
for (auto out : op_node->outputs) {
auto* real_node = GetRealNode(block_idx, out);
if (NodeVarHasDtype(real_node)) {
for (auto* out : op_node->outputs) {
if (!out->IsVar()) continue;
auto* real_node = GetRealVarNode(block_idx, out);
if (VarNodeHasDtype(real_node)) {
if (in_name_to_node.count(out->Name()))
real_node->Var()->SetDataType(
in_name_to_node[out->Name()]->Var()->GetDataType());
......@@ -591,32 +580,46 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
//
// If a var(op's out var) appears multiple times in a block, we should not
// If a var(op's out var) appears multiple times in graph, we should not
// convert to fp16.
else if (black_list_.count(op_type) == 0 && // NOLINT
!VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
bool support_precision =
OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
// if op not has float input, we will not choose the low precision kernel.
// If the op has no input and output of float type, we will not choose the
// low precision kernel.
{
bool has_float_input{false};
for (auto in_node : op_node->inputs) {
auto* real_node = GetRealNode(block_idx, in_node);
if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
real_node->Var()->GetDataType() == proto::VarType::FP32 ||
real_node->Var()->GetDataType() == proto::VarType::FP64 ||
real_node->Var()->GetDataType() == proto::VarType::BF16) {
has_float_input = true;
bool has_float_input_and_output{false};
for (auto* in_node : op_node->inputs) {
if (!in_node->IsVar()) continue;
auto* real_node = GetRealVarNode(block_idx, in_node);
if (real_node->Var()->GetDataType() == VarType::FP16 ||
real_node->Var()->GetDataType() == VarType::FP32 ||
real_node->Var()->GetDataType() == VarType::FP64 ||
real_node->Var()->GetDataType() == VarType::BF16) {
has_float_input_and_output = true;
break;
}
}
if (!has_float_input) {
for (auto* out_node : op_node->outputs) {
if (!out_node->IsVar()) continue;
auto* real_node = GetRealVarNode(block_idx, out_node);
if (real_node->Var()->GetDataType() == VarType::FP16 ||
real_node->Var()->GetDataType() == VarType::FP32 ||
real_node->Var()->GetDataType() == VarType::FP64 ||
real_node->Var()->GetDataType() == VarType::BF16) {
has_float_input_and_output = true;
break;
}
}
if (!has_float_input_and_output) {
support_precision = false;
VLOG(2) << " op doesn't has float input, just skip.";
VLOG(2) << " op doesn't has float input and output, just skip.";
}
}
VLOG(2) << " support low precision " << support_precision;
VLOG(2) << "op type: " << op_type
<< " support low precision: " << support_precision;
if (support_precision) {
VLOG(2) << " process input nodes:";
......@@ -626,8 +629,8 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
// Just for paddle's terriable case: op's input and output has the same
// name.
std::unordered_map<std::string, std::string> names_map;
for (auto out_node : op_node->outputs) {
for (auto in_node : op_node->inputs) {
for (auto* out_node : op_node->outputs) {
for (auto* in_node : op_node->inputs) {
if (out_node->Name() == in_node->Name()) {
names_map[out_node->Name()] = in_node->Name();
}
......@@ -655,7 +658,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
op_node,
&suffix_,
block_desc,
framework::proto::VarType::FP32,
VarType::FP32,
block_idx);
}
}
......@@ -665,21 +668,19 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
// - add cast op if the input dtype is not fp32.
else { // NOLINT
VLOG(3) << "not to run fp16 op_type: " << op_type;
auto ins = op_node->inputs;
for (auto* in_node : ins) {
for (auto* in_node : op_node->inputs) {
auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) {
AddCastOp(graph,
in_node,
op_node,
to_type,
framework::proto::VarType::FP32,
VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
<< cast_map_[in_node]->Name() << "("
<< framework::proto::VarType::FP32 << ")";
<< cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")";
}
}
}
......@@ -688,31 +689,30 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
// 4. if output_op's dtype is not compatible to output dtype, then just
// insert cast.
for (auto* node : output_nodes) {
ir::Node* fetch_op{nullptr};
framework::ir::Node* fetch_op{nullptr};
for (auto* op_node : node->outputs) {
if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
fetch_op = op_node;
}
}
CHECK_NOTNULL(fetch_op);
auto var = node->Var();
auto* var = node->Var();
if (keep_io_types_ && var->GetDataType() == to_type) {
// fp16/bf16 -> fp32.
AddCastOp(graph,
node,
fetch_op,
to_type,
framework::proto::VarType::FP32,
VarType::FP32,
&suffix_,
block_desc,
&cast_map_);
} else if (!keep_io_types_ &&
var->GetDataType() == framework::proto::VarType::FP32) {
} else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
// fp32 -> fp16/bf16
AddCastOp(graph,
node,
fetch_op,
framework::proto::VarType::FP32,
VarType::FP32,
to_type,
&suffix_,
block_desc,
......@@ -720,13 +720,15 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
}
}
for (auto node : graph->Nodes()) {
auto* real_node = GetRealNode(block_idx, node);
if (!NodeVarHasDtype(real_node)) continue;
for (auto* node : graph->Nodes()) {
if (!node->IsVar()) continue;
auto* real_node = GetRealVarNode(block_idx, node);
if (!VarNodeHasDtype(real_node)) continue;
if (vars_in_multi_block_map_.count(real_node->Name()) &&
vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
vars_in_multi_block_map_.at(real_node->Name()).first =
if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
block_idx) {
vars_in_multi_block_with_pair_.at(real_node->Name()).first =
real_node->Var()->GetDataType();
}
}
......@@ -757,17 +759,15 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
paddle::CPUPlace place;
auto parameters = scope_.LocalVarNames();
std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : main_graph_->Nodes()) {
if (!(node->IsVar())) continue;
if (NodeVarHasDtype(node)) {
if (!node->IsVar()) continue;
if (VarNodeHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
node->Var()->GetDataType() == VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
......@@ -777,26 +777,27 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
mixed_tensor.set_type(DTYPE); \
auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int i = 0; i < t->numel(); i++) { \
mixed_data[i] = static_cast<dtype>(data[i]); \
for (int64_t i = 0; i < origin_tensor->numel(); i++) { \
mixed_data[i] = static_cast<dtype>(origin_data[i]); \
} \
t->clear(); \
paddle::framework::TensorCopySync(mixed_tensor, place, t)
origin_tensor->clear(); \
paddle::framework::TensorCopySync( \
mixed_tensor, platform::CPUPlace(), origin_tensor)
for (const auto& param_name : parameters) {
if (weights_should_be_fp32.count(param_name)) continue;
auto* var = scope_.FindLocalVar(param_name);
if (var->IsType<phi::DenseTensor>()) {
auto* t = var->GetMutable<phi::DenseTensor>();
if (t->dtype() != phi::DataType::FLOAT32) continue;
auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
phi::DenseTensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision_ == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.Resize(origin_tensor->dims());
auto* origin_data =
origin_tensor->mutable_data<float>(platform::CPUPlace());
if (mixed_precision_ == phi::DataType::FLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
} else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
} else if (mixed_precision_ == phi::DataType::BFLOAT16) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
}
......@@ -851,8 +852,8 @@ void AddCastOp(
framework::ir::Graph* graph,
framework::ir::Node* node,
framework::ir::Node* next_op,
framework::proto::VarType::Type from_type,
framework::proto::VarType::Type to_type,
VarType::Type from_type,
VarType::Type to_type,
int* suffix,
framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
......@@ -913,14 +914,15 @@ bool OpSupportPrecision(const std::string& op_type,
return support_precision;
}
void ConvertToMixedPrecision(const std::string& model_file,
void ConvertToMixedPrecision(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
std::unordered_set<std::string> black_list) {
const std::unordered_set<std::string>& black_list) {
ConvertToMixedPrecisionPass pass(model_file,
params_file,
mixed_model_file,
......
......@@ -51,8 +51,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types = true,
std::unordered_set<std::string> black_list = {});
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
} // namespace analysis
} // namespace inference
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册