未验证 提交 d373f4ff 编写于 作者: W Wilber 提交者: GitHub

fix some convert error found in tipc. (#44457)

* fix some error found in tipc.

* update
上级 37455714
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <string>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
...@@ -22,6 +23,7 @@ ...@@ -22,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
...@@ -89,6 +91,31 @@ bool OutShouldNotConvert(ir::Node* var_node) { ...@@ -89,6 +91,31 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return false; return false;
} }
// Get weight names which appear in multiple block (block 0 and block n).
std::unordered_set<std::string> GetMultiBlockPersistableNames(
framework::ProgramDesc* program_desc) {
std::unordered_set<std::string> special_weights;
size_t block_size = program_desc->Size();
std::unordered_set<std::string> block_0_weights;
for (auto var : program_desc->Block(0).AllVars()) {
if (var->Persistable()) block_0_weights.insert(var->Name());
}
for (size_t i = 1; i < block_size; ++i) {
// std::cout << program_desc->MutableBlock(i)->Proto()->DebugString() <<
// std::endl;;
auto all_ops = program_desc->Block(i).AllOps();
for (auto op : all_ops) {
for (auto name : op->InputArgumentNames()) {
if (block_0_weights.count(name)) special_weights.insert(name);
}
}
}
return special_weights;
}
// Just process special cases for weights conversion. // Just process special cases for weights conversion.
bool WeightsShouldNotConvert(ir::Node* var_node) { bool WeightsShouldNotConvert(ir::Node* var_node) {
auto op_nodes = var_node->outputs; auto op_nodes = var_node->outputs;
...@@ -116,19 +143,139 @@ bool WeightsShouldNotConvert(ir::Node* var_node) { ...@@ -116,19 +143,139 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
} }
} }
// If cur_op's next is condition_flow op, then cur op should be fp32. Note, we
// now only convert to mixed in block 0.
for (auto* op_node : op_nodes) {
for (auto var : op_node->outputs) {
for (auto next_op : var->outputs) {
if (next_op->Op()->HasAttr("sub_block")) {
return true;
}
}
}
}
return false; return false;
} }
inline bool IsFloatVarType(framework::proto::VarType::Type type) { inline bool IsFloatVarType(framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 || if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 || type == framework::proto::VarType::FP32 ||
type == framework::proto::VarType::BF16 || type == framework::proto::VarType::BF16)
type == framework::proto::VarType::FP64)
return true; return true;
return false; return false;
} }
void ConvertTensorDtype(framework::ir::Graph* graph, void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
if (op_type == "fill_constant") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "assign_value") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "eye") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"dtype", static_cast<int>(framework::proto::VarType::FP32));
} else if (op_type == "cast") {
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"in_dtype", static_cast<int>(framework::proto::VarType::FP32));
if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
static_cast<int>(framework::proto::VarType::FP64))
op_node->Op()->SetAttr(
"out_dtype", static_cast<int>(framework::proto::VarType::FP32));
}
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (!in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP64) {
in_var->SetDataType(framework::proto::VarType::FP32);
}
}
}
}
// Handle special ops which contains dtype attribute. e.g., fill_constant,
// assign_value.
void HandleSpecialOps(framework::OpDesc* op_desc) {
if (op_desc->Type() == "fill_constant") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "assign_value") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "eye") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
}
}
// We modify op's input output precision, and we need to fix cast op in_dtype
// and out_dtype attribute.
void FixCastAttr(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type != "cast") continue;
auto input = op_node->inputs[0];
auto output = op_node->outputs[0];
op_node->Op()->SetAttr("in_dtype",
static_cast<int>(input->Var()->GetDataType()));
op_node->Op()->SetAttr("out_dtype",
static_cast<int>(output->Var()->GetDataType()));
}
}
// If op's output var is condition flow op's input, then the op must be fp32
// precision.
bool NextOpIncludesConditionFlowOp(framework::ir::Node* cur_op_node) {
auto cur_op_outs = cur_op_node->outputs;
for (auto out_var : cur_op_outs) {
for (auto next_op_node : out_var->outputs) {
if (next_op_node->Op()->HasAttr("sub_block")) {
return true;
}
}
}
return false;
}
void ConvertTensorDtype(framework::ProgramDesc* program_desc,
framework::ir::Graph* graph,
const std::unordered_set<std::string>& blacklist, const std::unordered_set<std::string>& blacklist,
bool keep_io_types, bool keep_io_types,
phi::Backend backend, phi::Backend backend,
...@@ -145,13 +292,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -145,13 +292,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
static_cast<int>(tensor_dtype))); static_cast<int>(tensor_dtype)));
} }
auto weight_name_in_multi_block = GetMultiBlockPersistableNames(program_desc);
int num_low_precision = 0; int num_low_precision = 0;
int suffix = 0; int suffix = 0;
framework::BlockDesc* block_desc{nullptr}; framework::BlockDesc* block_desc{nullptr};
std::vector<framework::ir::Node*> output_nodes; std::vector<framework::ir::Node*> output_nodes;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map; std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map;
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : framework::ir::TopologySortOperations(*graph)) { for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue; if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type(); auto op_type = op_node->Op()->Type();
auto phi_op_type = phi::TransToPhiKernelName(op_type); auto phi_op_type = phi::TransToPhiKernelName(op_type);
...@@ -167,18 +315,29 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -167,18 +315,29 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
auto* fetch_var = op_node->inputs[0]; auto* fetch_var = op_node->inputs[0];
output_nodes.push_back(fetch_var); output_nodes.push_back(fetch_var);
continue; continue;
} else if (op_type == "cast") {
continue;
} }
// 2. if op support fp16/bf16 and not in blacklist. // 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16. // - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16. // - add cast op if the input dtype is not fp16/bf16.
// - set output dtype. // - set output dtype.
else if (blacklist.count(phi_op_type) == 0) { // NOLINT else if (blacklist.count(phi_op_type) == 0 && // NOLINT
!NextOpIncludesConditionFlowOp(op_node)) {
bool support_precision = bool support_precision =
OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist); OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist);
VLOG(2) << "phi_op_type " << phi_op_type << " support low precision " VLOG(2) << "op_type " << op_type << ", phi_op_type " << phi_op_type
<< support_precision; << " support low precision " << support_precision << ", "
<< reinterpret_cast<void*>(op_node->Op()->Block());
for (auto in_node : op_node->inputs) {
if (weight_name_in_multi_block.count(in_node->Name()))
support_precision = false;
}
if (support_precision) { if (support_precision) {
HandleSpecialOps(op_node->Op());
++num_low_precision; ++num_low_precision;
auto inputs = op_node->inputs; auto inputs = op_node->inputs;
for (auto* in_node : inputs) { for (auto* in_node : inputs) {
...@@ -232,9 +391,8 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -232,9 +391,8 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// 3. check op not support fp16/bf16 or in blacklist. // 3. check op not support fp16/bf16 or in blacklist.
// - add cast op if the input dtype is not fp32. // - add cast op if the input dtype is not fp32.
else { // NOLINT else { // NOLINT
// trt pass should explicitle add cast op is input is bf16/tf32, etc. auto ins = op_node->inputs;
if (op_node->Name() == "tensorrt_engine") continue; for (auto* in_node : ins) {
for (auto* in_node : op_node->inputs) {
if (in_node->IsCtrlVar()) continue; if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var(); auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) { if (in_var->GetDataType() == to_type) {
...@@ -366,8 +524,14 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -366,8 +524,14 @@ void ConvertToMixedPrecision(const std::string& model_file,
auto graph = std::unique_ptr<framework::ir::Graph>( auto graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc)); new framework::ir::Graph(*program_desc));
ConvertTensorDtype( ConvertAllFp64ToFp32(graph.get());
graph.get(), black_list, keep_io_types, backend, mixed_precision); ConvertTensorDtype(program_desc.get(),
graph.get(),
black_list,
keep_io_types,
backend,
mixed_precision);
FixCastAttr(graph.get());
framework::ProgramDesc mixed_program_desc; framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*graph, &mixed_program_desc); framework::ir::GraphToProgram(*graph, &mixed_program_desc);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册