Unverified commit 0972d6ac, authored by Yuanle Liu, committed by GitHub

[Paddle Inference] improve convert_to_mixed_precision (#47333)

Parent 5429d145
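For orientation, here is a minimal sketch of how the public entry point touched by this commit might be invoked after the change; the model paths, backend choice, and blacklist entry are illustrative placeholders and not part of the diff:

// Hypothetical caller; assumes the relevant Paddle inference analysis headers are included.
std::unordered_set<std::string> black_list{"reduce_mean"};  // ops forced to stay in FP32 (placeholder)
paddle::inference::analysis::ConvertToMixedPrecision(
    "model.pdmodel",           // model_file (placeholder path)
    "model.pdiparams",         // params_file
    "mixed.pdmodel",           // mixed_model_file
    "mixed.pdiparams",         // mixed_params_file
    phi::DataType::FLOAT16,    // mixed_precision
    phi::Backend::GPU,         // backend
    true,                      // keep_io_types
    black_list);               // now taken by const reference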
@@ -42,13 +42,13 @@
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/tensor_meta.h"
-using namespace paddle::framework;  // NOLINT
 namespace paddle {
 namespace inference {
 namespace analysis {
 namespace {
+using VarType = framework::proto::VarType;
 bool PhiKernelSupportPrecision(
     const std::string& op_type,
     phi::Backend backend,
@@ -73,13 +73,14 @@ bool GpuKernelSupportPrecision(
       phi_op_type, phi::Backend::GPUDNN, data_type, layout);
   if (!res) {
-    auto& all_kernels = OperatorWithKernel::AllOpKernels();
+    auto& all_kernels = framework::OperatorWithKernel::AllOpKernels();
     auto it = all_kernels.find(op_type);
     if (it != all_kernels.end()) {
       for (auto& kern_pair : it->second) {
         if (platform::is_gpu_place(kern_pair.first.place_) &&
-            kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
+            kern_pair.first.data_type_ == VarType::FP16) {
           res = true;
+          break;
         }
       }
     }
@@ -88,6 +89,8 @@ bool GpuKernelSupportPrecision(
 }

 class ConvertToMixedPrecisionPass {
+  using BlockID = size_t;
+
  public:
   explicit ConvertToMixedPrecisionPass(
       const std::string& model_file,
@@ -97,7 +100,7 @@ class ConvertToMixedPrecisionPass {
       phi::DataType mixed_precision,
       phi::Backend backend,
       bool keep_io_types,
-      std::unordered_set<std::string> black_list)
+      const std::unordered_set<std::string>& black_list)
       : model_file_(model_file),
         params_file_(params_file),
         mixed_model_file_(mixed_model_file),
@@ -107,45 +110,40 @@ class ConvertToMixedPrecisionPass {
         keep_io_types_(keep_io_types),
         black_list_(black_list),
         place_(paddle::CPUPlace()),
-        executor_(place_) {
-    black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
-  }
+        executor_(place_) {}

   void Run();

  private:
   void LoadAndPrepare();
-  inline bool NodeVarHasDtype(framework::ir::Node* node);
+  inline bool VarNodeHasDtype(framework::ir::Node* node);
   void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
   void FixCastAttr(framework::ir::Graph* graph);
   void SaveMixedModel();
-  void ConvertTensorDtype(int block_idx);
+  void ConvertTensorDtype(BlockID block_idx);
   void ProcessInputNode(bool support_precision,
-                        ir::Node* in_node,
-                        ir::Node* op_node,
+                        framework::ir::Node* in_node,
+                        framework::ir::Node* op_node,
                         int* suffix,
                         framework::BlockDesc* block_desc,
-                        framework::proto::VarType::Type to_type,
-                        int block_idx);
-  void ProcessOutputNode(int block_idx,
-                         ir::Node* var_node,
-                         framework::proto::VarType::Type to_type);
-  inline bool IsFloatVarType(framework::proto::VarType::Type type);
-  bool OutShouldNotConvert(ir::Node* var_node);
+                        VarType::Type to_type,
+                        BlockID block_idx);
+  void ProcessOutputNode(BlockID block_idx,
+                         framework::ir::Node* var_node,
+                         VarType::Type to_type);
+  inline bool IsFloatVarType(VarType::Type type);
+  bool OutShouldNotConvert(framework::ir::Node* var_node);
   // Just process special cases for weights conversion.
-  bool WeightsShouldNotConvert(ir::Node* var_node);
+  bool WeightsShouldNotConvert(framework::ir::Node* var_node);
   // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
+  framework::ir::Node* GetRealVarNode(BlockID block_idx,
+                                      framework::ir::Node* node);
   void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(int block_idx,
+  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
                                         framework::ir::Node* op_node);

  private:
@@ -167,11 +165,10 @@ class ConvertToMixedPrecisionPass {
   framework::Scope scope_;
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string,
-                     std::pair<framework::proto::VarType::Type, int>>
-      vars_in_multi_block_map_;
-  std::vector<std::unordered_map<std::string, std::vector<std::string>>>
-      vars_appear_multi_in_one_block_;
+  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
+      vars_in_multi_block_with_pair_;
+  std::unordered_map<std::string, std::vector<std::string>>
+      vars_in_multi_block_with_ops_;
   int suffix_{0};

   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
@@ -179,91 +176,84 @@ class ConvertToMixedPrecisionPass {
   std::vector<framework::ir::Graph*> graphes_;
 };

-framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
-    int block_idx, framework::ir::Node* node) {
-  if (vars_in_multi_block_map_.count(node->Name())) {
-    int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
-    if (block_idx != var_origin_block_id) {
-      auto graph = graphes_[var_origin_block_id];
-      for (auto nd : graph->Nodes()) {
-        if (nd->Name() == node->Name()) {
-          return nd;
+framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
+    BlockID block_idx, framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+
+  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
+    auto origin_blockId =
+        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
+    if (block_idx != origin_blockId) {
+      auto* graph = graphes_[origin_blockId];
+      for (auto* node : graph->Nodes()) {
+        if (node->Name() == var_node->Name()) {
+          return node;
         }
       }
     }
   }
-  return node;
+
+  return var_node;
 }
-inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
-    framework::ir::Node* node) {
-  if (node->IsVar() &&
-      (node->Var()->GetType() ==
-           paddle::framework::proto::VarType::SELECTED_ROWS ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR ||
-       node->Var()->GetType() ==
-           paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
-       node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
-    return true;
-  }
-  return false;
-}
+inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
+    framework::ir::Node* var_node) {
+  CHECK_EQ(var_node->IsVar(), true);
+  auto type = var_node->Var()->GetType();
+  return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) ||
+         (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) ||
+         (type == VarType::VOCAB);
+}
 // op1(fp32) -> var1, op2(fp16) -> var1
 // if and only if op1 and op2 both support fp16, we convert op1 and op2's
 // precision.
 inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    int block_idx, framework::ir::Node* op_node) {
+    BlockID block_idx, framework::ir::Node* op_node) {
   CHECK_EQ(op_node->IsOp(), true);
-  bool ret{false};
-
-  for (auto* out : op_node->outputs) {
-    auto* real_node = GetRealNode(block_idx, out);
-    if (!real_node->Var()->Persistable() &&
-        vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
-      for (auto op_type :
-           vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
-        if (OpSupportPrecision(
+
+  for (auto* var_node : op_node->outputs) {
+    if (!var_node->IsVar()) continue;
+    auto* real_var_node = GetRealVarNode(block_idx, var_node);
+    if (!real_var_node->Var()->Persistable() &&
+        vars_in_multi_block_with_ops_.count(var_node->Name())) {
+      for (const auto& op_type :
+           vars_in_multi_block_with_ops_.at(var_node->Name())) {
+        if (!OpSupportPrecision(
                 op_type, backend_, mixed_precision_, black_list_)) {
-          ret = true;
-          VLOG(2) << out->Name()
+          VLOG(2) << var_node->Name()
                   << " is multi precision op's out, so we skip convert to fp16";
-          break;
+          return true;
         }
       }
     }
-    if (ret) break;
   }
-  return ret;
+  return false;
 }
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
-    ir::Node* in_node,
-    ir::Node* op_node,
+    framework::ir::Node* in_node,
+    framework::ir::Node* op_node,
     int* suffix,
     framework::BlockDesc* block_desc,
-    framework::proto::VarType::Type to_type,
-    int block_idx) {
-  auto* real_node = GetRealNode(block_idx, in_node);
-  if (!NodeVarHasDtype(real_node)) return;
-  auto graph = graphes_[block_idx];
+    VarType::Type to_type,
+    BlockID block_idx) {
+  if (!in_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, in_node);
+  if (!VarNodeHasDtype(real_node)) return;
+  auto* graph = graphes_[block_idx];
   bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
+  bool is_in_multi_block = vars_in_multi_block_with_pair_.count(in_var->Name());
   if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
+    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
   }
   if (support_precision) {
-    if (in_var->Persistable() &&
-        in_var_type == framework::proto::VarType::FP32) {
+    if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
       in_var->SetDataType(to_type);
       in_var_type = to_type;
@@ -300,14 +290,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 }

 void ConvertToMixedPrecisionPass::ProcessOutputNode(
-    int block_idx,
-    ir::Node* var_node,
-    framework::proto::VarType::Type to_type) {
-  auto* real_node = GetRealNode(block_idx, var_node);
-  if (!NodeVarHasDtype(real_node)) return;
+    BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
+  if (!var_node->IsVar()) return;
+  auto* real_node = GetRealVarNode(block_idx, var_node);
+  if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
-  if (out_var->GetDataType() == framework::proto::VarType::FP32) {
+  if (out_var->GetDataType() == VarType::FP32) {
     if (OutShouldNotConvert(var_node)) return;
     out_var->SetDataType(to_type);
   }
@@ -316,7 +305,8 @@ void ConvertToMixedPrecisionPass::ProcessOutputNode(
 }

 // Just process special cases.
-bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::OutShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_node = var_node->inputs[0];
   auto* op_desc = op_node->Op();
@@ -343,7 +333,8 @@ bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
   return false;
 }

-bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
+bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(
+    framework::ir::Node* var_node) {
   auto op_nodes = var_node->outputs;
   for (auto* op_node : op_nodes) {
     auto* op_desc = op_node->Op();
@@ -391,13 +382,10 @@ bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
   return false;
 }

-inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
-    framework::proto::VarType::Type type) {
-  if (type == framework::proto::VarType::FP16 ||
-      type == framework::proto::VarType::FP32 ||
-      type == framework::proto::VarType::BF16)
-    return true;
-  return false;
+inline bool ConvertToMixedPrecisionPass::IsFloatVarType(VarType::Type type) {
+  return (type == VarType::FP16) || (type == VarType::FP32) ||
+         (type == VarType::BF16);
 }

 void ConvertToMixedPrecisionPass::LoadAndPrepare() {
@@ -405,6 +393,10 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
       inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
+  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
+    auto* graph = main_graph_->GetSubGraph(i);
+    graphes_.push_back(graph);
+  }

   // Remove all control var
   IrInferCleanGraphPass pass;
@@ -412,41 +404,45 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);

-  vars_appear_multi_in_one_block_.resize(program_desc_->Size());
   FindVarsInMultiBlock();
 }

 void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (size_t i = 0; i < program_desc_->Size(); ++i) {
-    for (auto op : program_desc_->Block(i).AllOps()) {
-      auto in_names = op->InputArgumentNames();
-      block_var_names_set[i].insert(in_names.begin(), in_names.end());
-      auto out_names = op->OutputArgumentNames();
+  std::unordered_set<std::string> all_var_names_set;
+  std::vector<std::unordered_set<std::string>> block_var_names_set(
+      program_desc_->Size());
+  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+    for (auto* op : program_desc_->Block(idx).AllOps()) {
+      const auto& in_names = op->InputArgumentNames();
+      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
+      const auto& out_names = op->OutputArgumentNames();
+      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
       if (op->HasAttr("sub_block") == false) {
-        for (auto& n : out_names) {
-          if (block_var_names_set[i].count(n)) {
-            vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
+        for (const auto& name : out_names) {
+          if (all_var_names_set.count(name)) {
+            vars_in_multi_block_with_ops_[name].push_back(op->Type());
          }
        }
      }
-      block_var_names_set[i].insert(out_names.begin(), out_names.end());
+      all_var_names_set.insert(block_var_names_set[idx].begin(),
+                               block_var_names_set[idx].end());
    }
  }

-  for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
-    for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
-      std::set<std::string> vars_in_multi_block;
-      std::set_intersection(
-          block_var_names_set[i].begin(),
-          block_var_names_set[i].end(),
-          block_var_names_set[j].begin(),
-          block_var_names_set[j].end(),
-          std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
+  CHECK_GT(program_desc_->Size(), 0U);
+  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
+    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
+      std::vector<std::string> vars_in_multi_block;
+      std::set_intersection(block_var_names_set[idx].begin(),
+                            block_var_names_set[idx].end(),
+                            block_var_names_set[jdx].begin(),
+                            block_var_names_set[jdx].end(),
+                            std::back_inserter(vars_in_multi_block));
-      for (auto name : vars_in_multi_block) {
-        vars_in_multi_block_map_.emplace(
-            name, std::make_pair(framework::proto::VarType::FP32, i));
+      for (const auto& name : vars_in_multi_block) {
+        vars_in_multi_block_with_pair_.emplace(
+            name, std::make_pair(VarType::FP32, idx));
       }
     }
   }
@@ -462,41 +458,34 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
     if (op_type == "fill_constant") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "assign_value") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "eye") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "fill_any_like") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
     } else if (op_type == "cast") {
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "in_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
       if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(framework::proto::VarType::FP64))
-        op_node->Op()->SetAttr(
-            "out_dtype", static_cast<int>(framework::proto::VarType::FP32));
+          static_cast<int>(VarType::FP64))
+        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
     }

     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
-      if (!in_var->Persistable() &&
-          in_var->GetDataType() == framework::proto::VarType::FP64) {
-        in_var->SetDataType(framework::proto::VarType::FP32);
+      if (!in_var->Persistable() && in_var->GetDataType() == VarType::FP64) {
+        in_var->SetDataType(VarType::FP32);
       }
     }
   }
@@ -505,9 +494,8 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
 void ConvertToMixedPrecisionPass::Run() {
   LoadAndPrepare();

-  for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
-    auto graph = main_graph_->GetSubGraph(i);
-    graphes_.push_back(graph);
+  for (size_t i = 0; i < graphes_.size(); ++i) {
+    auto* graph = graphes_[i];
     VLOG(2) << " -------- handle subgraph " << i << ", has "
             << graph->Nodes().size() << " nodes --------";
@@ -518,19 +506,19 @@ void ConvertToMixedPrecisionPass::Run() {
     // A trick
     PatchForStrangeOp();

-    CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
+    CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }

   SaveMixedModel();
 }

-void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
-  auto graph = graphes_[block_idx];
-  framework::proto::VarType::Type to_type;
+void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
+  auto* graph = graphes_[block_idx];
+  VarType::Type to_type;
   if (mixed_precision_ == phi::DataType::FLOAT16) {
-    to_type = framework::proto::VarType::FP16;
+    to_type = VarType::FP16;
   } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
-    to_type = framework::proto::VarType::BF16;
+    to_type = VarType::BF16;
   } else {
     PADDLE_THROW(paddle::platform::errors::InvalidArgument(
         "mixed_precision currently not supported dtype %d, we now only "
@@ -551,8 +539,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // 1. set input dtype.
     if (op_type == "feed") {
       auto feed_var = op_node->outputs[0]->Var();
-      if (!keep_io_types_ &&
-          feed_var->GetDataType() == framework::proto::VarType::FP32) {
+      if (!keep_io_types_ && feed_var->GetDataType() == VarType::FP32) {
         feed_var->SetDataType(to_type);
       }
     } else if (op_type == "fetch") {
@@ -568,15 +555,17 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
       // same name.
       std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
       for (auto* in : op_node->inputs) {
-        auto* real_node = GetRealNode(block_idx, in);
-        if (NodeVarHasDtype(real_node)) {
+        if (!in->IsVar()) continue;
+        auto* real_node = GetRealVarNode(block_idx, in);
+        if (VarNodeHasDtype(real_node)) {
           in_name_to_node[in->Name()] = in;
         }
       }

-      for (auto out : op_node->outputs) {
-        auto* real_node = GetRealNode(block_idx, out);
-        if (NodeVarHasDtype(real_node)) {
+      for (auto* out : op_node->outputs) {
+        if (!out->IsVar()) continue;
+        auto* real_node = GetRealVarNode(block_idx, out);
+        if (VarNodeHasDtype(real_node)) {
           if (in_name_to_node.count(out->Name()))
             real_node->Var()->SetDataType(
                 in_name_to_node[out->Name()]->Var()->GetDataType());
@@ -591,32 +580,46 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp16/bf16.
     // - set output dtype.
     //
-    // If a var(op's out var) appears multiple times in a block, we should not
+    // If a var(op's out var) appears multiple times in graph, we should not
     // convert to fp16.
     else if (black_list_.count(op_type) == 0 &&  // NOLINT
             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);

-      // if op not has float input, we will not choose the low precision kernel.
+      // If the op has no input and output of float type, we will not choose the
+      // low precision kernel.
       {
-        bool has_float_input{false};
-        for (auto in_node : op_node->inputs) {
-          auto* real_node = GetRealNode(block_idx, in_node);
-          if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP32 ||
-              real_node->Var()->GetDataType() == proto::VarType::FP64 ||
-              real_node->Var()->GetDataType() == proto::VarType::BF16) {
-            has_float_input = true;
+        bool has_float_input_and_output{false};
+        for (auto* in_node : op_node->inputs) {
+          if (!in_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, in_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
             break;
           }
         }
-        if (!has_float_input) {
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(block_idx, out_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_input_and_output = true;
+            break;
+          }
+        }
+        if (!has_float_input_and_output) {
           support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't has float input and output, just skip.";
         }
       }
-      VLOG(2) << " support low precision " << support_precision;
+      VLOG(2) << "op type: " << op_type
+              << " support low precision: " << support_precision;

       if (support_precision) {
         VLOG(2) << " process input nodes:";
@@ -626,8 +629,8 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
         // Just for paddle's terriable case: op's input and output has the same
         // name.
         std::unordered_map<std::string, std::string> names_map;
-        for (auto out_node : op_node->outputs) {
-          for (auto in_node : op_node->inputs) {
+        for (auto* out_node : op_node->outputs) {
+          for (auto* in_node : op_node->inputs) {
            if (out_node->Name() == in_node->Name()) {
              names_map[out_node->Name()] = in_node->Name();
            }
@@ -655,7 +658,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
                          op_node,
                          &suffix_,
                          block_desc,
-                         framework::proto::VarType::FP32,
+                         VarType::FP32,
                          block_idx);
         }
       }
@@ -665,21 +668,19 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     // - add cast op if the input dtype is not fp32.
     else {  // NOLINT
       VLOG(3) << "not to run fp16 op_type: " << op_type;
-      auto ins = op_node->inputs;
-      for (auto* in_node : ins) {
+      for (auto* in_node : op_node->inputs) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
                     in_node,
                     op_node,
                     to_type,
-                    framework::proto::VarType::FP32,
+                    VarType::FP32,
                     &suffix_,
                     block_desc,
                     &cast_map_);
           VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
-                  << cast_map_[in_node]->Name() << "("
-                  << framework::proto::VarType::FP32 << ")";
+                  << cast_map_[in_node]->Name() << "(" << VarType::FP32 << ")";
         }
       }
     }
@@ -688,31 +689,30 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
   // 4. if output_op's dtype is not compatible to output dtype, then just
   // insert cast.
   for (auto* node : output_nodes) {
-    ir::Node* fetch_op{nullptr};
+    framework::ir::Node* fetch_op{nullptr};
     for (auto* op_node : node->outputs) {
       if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
         fetch_op = op_node;
       }
     }
     CHECK_NOTNULL(fetch_op);
-    auto var = node->Var();
+    auto* var = node->Var();
     if (keep_io_types_ && var->GetDataType() == to_type) {
       // fp16/bf16 -> fp32.
       AddCastOp(graph,
                 node,
                 fetch_op,
                 to_type,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 &suffix_,
                 block_desc,
                 &cast_map_);
-    } else if (!keep_io_types_ &&
-               var->GetDataType() == framework::proto::VarType::FP32) {
+    } else if (!keep_io_types_ && var->GetDataType() == VarType::FP32) {
       // fp32 -> fp16/bf16
       AddCastOp(graph,
                 node,
                 fetch_op,
-                framework::proto::VarType::FP32,
+                VarType::FP32,
                 to_type,
                 &suffix_,
                 block_desc,
@@ -720,13 +720,15 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
     }
   }

-  for (auto node : graph->Nodes()) {
-    auto* real_node = GetRealNode(block_idx, node);
-    if (!NodeVarHasDtype(real_node)) continue;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsVar()) continue;
+    auto* real_node = GetRealVarNode(block_idx, node);
+    if (!VarNodeHasDtype(real_node)) continue;
-    if (vars_in_multi_block_map_.count(real_node->Name()) &&
-        vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
-      vars_in_multi_block_map_.at(real_node->Name()).first =
+    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
+        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
+            block_idx) {
+      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
           real_node->Var()->GetDataType();
     }
   }
@@ -757,17 +759,15 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   framework::ProgramDesc mixed_program_desc;
   framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);

-  paddle::CPUPlace place;
   auto parameters = scope_.LocalVarNames();
   std::sort(parameters.begin(), parameters.end());

   std::unordered_set<std::string> weights_should_be_fp32;
   for (auto* node : main_graph_->Nodes()) {
-    if (!(node->IsVar())) continue;
-    if (NodeVarHasDtype(node)) {
+    if (!node->IsVar()) continue;
+    if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
-          node->Var()->GetDataType() ==
-              paddle::framework::proto::VarType::FP32) {
+          node->Var()->GetDataType() == VarType::FP32) {
         VLOG(2) << "weights keep to fp32: " << node->Name();
         weights_should_be_fp32.insert(node->Name());
       }
@@ -777,26 +777,27 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
 #define CONVERT_TENSOR_DTYPE(DTYPE, dtype)                                   \
   mixed_tensor.set_type(DTYPE);                                              \
   auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
-  for (int i = 0; i < t->numel(); i++) {                                     \
-    mixed_data[i] = static_cast<dtype>(data[i]);                             \
+  for (int64_t i = 0; i < origin_tensor->numel(); i++) {                     \
+    mixed_data[i] = static_cast<dtype>(origin_data[i]);                      \
   }                                                                          \
-  t->clear();                                                                \
-  paddle::framework::TensorCopySync(mixed_tensor, place, t)
+  origin_tensor->clear();                                                    \
+  paddle::framework::TensorCopySync(                                         \
+      mixed_tensor, platform::CPUPlace(), origin_tensor)

   for (const auto& param_name : parameters) {
+    if (weights_should_be_fp32.count(param_name)) continue;
     auto* var = scope_.FindLocalVar(param_name);
     if (var->IsType<phi::DenseTensor>()) {
-      auto* t = var->GetMutable<phi::DenseTensor>();
-      if (t->dtype() != phi::DataType::FLOAT32) continue;
+      auto* origin_tensor = var->GetMutable<phi::DenseTensor>();
+      if (origin_tensor->dtype() != phi::DataType::FLOAT32) continue;
       phi::DenseTensor mixed_tensor;
-      mixed_tensor.Resize(t->dims());
-      auto* data = t->mutable_data<float>(platform::CPUPlace());
-      if (mixed_precision_ == phi::DataType::FLOAT16 &&
-          !weights_should_be_fp32.count(param_name)) {
+      mixed_tensor.Resize(origin_tensor->dims());
+      auto* origin_data =
+          origin_tensor->mutable_data<float>(platform::CPUPlace());
+      if (mixed_precision_ == phi::DataType::FLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
                              phi::dtype::float16);
-      } else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
-                 !weights_should_be_fp32.count(param_name)) {
+      } else if (mixed_precision_ == phi::DataType::BFLOAT16) {
         CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
                              phi::dtype::bfloat16);
       }
@@ -851,8 +852,8 @@ void AddCastOp(
     framework::ir::Graph* graph,
     framework::ir::Node* node,
     framework::ir::Node* next_op,
-    framework::proto::VarType::Type from_type,
-    framework::proto::VarType::Type to_type,
+    VarType::Type from_type,
+    VarType::Type to_type,
     int* suffix,
     framework::BlockDesc* block_desc,
     std::unordered_map<framework::ir::Node*, framework::ir::Node*>* map) {
@@ -913,14 +914,15 @@ bool OpSupportPrecision(const std::string& op_type,
   return support_precision;
 }

-void ConvertToMixedPrecision(const std::string& model_file,
-                             const std::string& params_file,
-                             const std::string& mixed_model_file,
-                             const std::string& mixed_params_file,
-                             phi::DataType mixed_precision,
-                             phi::Backend backend,
-                             bool keep_io_types,
-                             std::unordered_set<std::string> black_list) {
+void ConvertToMixedPrecision(
+    const std::string& model_file,
+    const std::string& params_file,
+    const std::string& mixed_model_file,
+    const std::string& mixed_params_file,
+    phi::DataType mixed_precision,
+    phi::Backend backend,
+    bool keep_io_types,
+    const std::unordered_set<std::string>& black_list) {
   ConvertToMixedPrecisionPass pass(model_file,
                                    params_file,
                                    mixed_model_file,
......
@@ -51,8 +51,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
                              const std::string& mixed_params_file,
                              phi::DataType mixed_precision,
                              phi::Backend backend,
-                             bool keep_io_types = true,
-                             std::unordered_set<std::string> black_list = {});
+                             bool keep_io_types,
+                             const std::unordered_set<std::string>& black_list);

 }  // namespace analysis
 }  // namespace inference
......