Unverified commit de6f15b6 authored by W Wilber, committed by GitHub

reconstruct code for convert_fp16 (#46428) (#47087)

Parent 2cc8797e
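For reference, here is a minimal sketch of how the refactored entry point might be invoked. The file names and the extra blacklist entry are hypothetical; the signature follows the ConvertToMixedPrecision declaration in this diff:

// Convert an FP32 inference model into an FP16 mixed-precision model.
paddle::inference::analysis::ConvertToMixedPrecision(
    "model.pdmodel",          // hypothetical input model file
    "model.pdiparams",        // hypothetical input params file
    "mixed_model.pdmodel",    // output model file
    "mixed_model.pdiparams",  // output params file
    phi::DataType::FLOAT16,   // target precision (FLOAT16 or BFLOAT16)
    phi::Backend::GPU,        // backend whose kernels are queried
    true,                     // keep_io_types: keep feed/fetch vars in fp32
    {"softmax"});             // hypothetical user-supplied blacklist entry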
@@ -16,6 +16,7 @@
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -31,9 +32,14 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_clean_pass.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/tensor_meta.h"
using namespace paddle::framework; // NOLINT
@@ -43,160 +49,6 @@ namespace inference {
namespace analysis {
namespace {
inline std::string SerializeParams(framework::Scope* scope,
const std::vector<std::string>& params) {
std::ostringstream os;
phi::CPUContext ctx;
for (const auto& param : params) {
VLOG(3) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL(
scope->FindVar(param),
platform::errors::NotFound("Block should already have a '%s' variable",
param));
auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
framework::SerializeToStream(os, *tensor, ctx);
}
return os.str();
}
inline void StrToBinary(const std::string& path, const std::string& str) {
std::ofstream file(path.c_str(), std::ios::binary);
file.write(str.c_str(), str.size());
file.close();
}
inline bool NodeVarHasDtype(framework::ir::Node* node) {
if (node->IsCtrlVar()) return false;
if (node->IsVar() &&
(node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
return true;
}
return false;
}
// Return the Node* which first appears in the block.
framework::ir::Node* GetRealNode(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
framework::ir::Node* node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
if (vars_in_multi_block_map->count(node->Name())) {
int var_origin_block_id = vars_in_multi_block_map->at(node->Name()).second;
if (block_idx != var_origin_block_id) {
auto graph = graphes[var_origin_block_id];
for (auto nd : graph->Nodes()) {
if (nd->Name() == node->Name()) {
return nd;
}
}
}
}
return node;
}
inline bool VarIsMultiOpsOut(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
framework::ir::Node* op_node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
CHECK_EQ(op_node->IsOp(), true);
for (auto* out : op_node->outputs) {
if (out->IsCtrlVar()) continue;
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (!real_node->Var()->Persistable() &&
vars_appear_multi_in_one_block[block_idx].count(out->Name())) {
VLOG(2) << out->Name()
<< " is multi op's out, so we skip convert to fp16";
return true;
}
}
return false;
}
void SaveMixedModel(
framework::ir::Graph* graph,
framework::Scope* scope,
framework::ProgramDesc* mixed_program_desc,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
const std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>&
vars_in_multi_block_map) {
paddle::CPUPlace place;
auto parameters = scope->LocalVarNames();
std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : graph->Nodes()) {
if (!(node->IsVar() && !node->IsCtrlVar())) continue;
if (NodeVarHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
for (const auto& param_name : parameters) {
auto* var = scope->FindLocalVar(param_name);
if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto* t = var->GetMutable<framework::LoDTensor>();
if (t->dtype() != phi::DataType::FLOAT32) continue;
framework::Tensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
} else if (mixed_precision == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<bfloat16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<bfloat16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
}
}
}
StrToBinary(mixed_model_file,
mixed_program_desc->Proto()->SerializeAsString());
StrToBinary(mixed_params_file, SerializeParams(scope, parameters));
}
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
@@ -235,8 +87,236 @@ bool GpuKernelSupportPrecision(
return res;
}
class ConvertToMixedPrecisionPass {
public:
explicit ConvertToMixedPrecisionPass(
const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
std::unordered_set<std::string> black_list)
: model_file_(model_file),
params_file_(params_file),
mixed_model_file_(mixed_model_file),
mixed_params_file_(mixed_params_file),
mixed_precision_(mixed_precision),
backend_(backend),
keep_io_types_(keep_io_types),
black_list_(black_list),
place_(paddle::CPUPlace()),
executor_(place_) {
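// These ops are force-added to the blacklist, presumably because their
// output dtype is fixed by a `dtype` attribute rather than inferred from
// their inputs (an assumption; the diff itself does not state the reason).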
black_list_.insert("assign");
black_list_.insert("fill_constant");
black_list_.insert("assign_value");
black_list_.insert("eye");
black_list_.insert("fill_any_like");
black_list_.insert("fill_constant_batch_size_like");
}
void Run();
private:
void LoadAndPrepare();
inline bool NodeVarHasDtype(framework::ir::Node* node);
void ConvertAllFp64ToFp32(framework::ir::Graph* graph);
void FixCastAttr(framework::ir::Graph* graph);
void SaveMixedModel();
void ConvertTensorDtype(int block_idx);
void ProcessInputNode(bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx);
void ProcessOutputNode(int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type);
inline bool IsFloatVarType(framework::proto::VarType::Type type);
bool OutShouldNotConvert(ir::Node* var_node);
// Just process special cases for weights conversion.
bool WeightsShouldNotConvert(ir::Node* var_node);
// To support multi block, we need to consider a lot of special cases.
// Return the Node* which first appears in the block.
framework::ir::Node* GetRealNode(int block_idx, framework::ir::Node* node);
void FindVarsInMultiBlock();
inline bool VarIsMultiPrecisionOpsOut(int block_idx,
framework::ir::Node* op_node);
private:
// A trick. Patch for strange ops whose input name equals the output name,
// such as `fused_multi_transformer`.
void PatchForStrangeOp();
private:
std::string model_file_;
std::string params_file_;
std::string mixed_model_file_;
std::string mixed_params_file_;
phi::DataType mixed_precision_;
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
paddle::CPUPlace place_;
framework::Executor executor_;
framework::Scope scope_;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>
vars_in_multi_block_map_;
std::vector<std::unordered_map<std::string, std::vector<std::string>>>
vars_appear_multi_in_one_block_;
int suffix_{0};
std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
std::vector<framework::ir::Graph*> graphes_;
};
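// If a node's var also lives in an earlier block, GetRealNode returns the
// node from the block where the var first appeared, so dtype updates made
// there are visible here.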
framework::ir::Node* ConvertToMixedPrecisionPass::GetRealNode(
int block_idx, framework::ir::Node* node) {
if (vars_in_multi_block_map_.count(node->Name())) {
int var_origin_block_id = vars_in_multi_block_map_.at(node->Name()).second;
if (block_idx != var_origin_block_id) {
auto graph = graphes_[var_origin_block_id];
for (auto nd : graph->Nodes()) {
if (nd->Name() == node->Name()) {
return nd;
}
}
}
}
return node;
}
inline bool ConvertToMixedPrecisionPass::NodeVarHasDtype(
framework::ir::Node* node) {
if (node->IsVar() &&
(node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB)) {
return true;
}
return false;
}
// op1(fp32) -> var1, op2(fp16) -> var1
// We convert op1's and op2's precision if and only if both op1 and op2
// support fp16.
inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
int block_idx, framework::ir::Node* op_node) {
CHECK_EQ(op_node->IsOp(), true);
bool ret{false};
for (auto* out : op_node->outputs) {
auto* real_node = GetRealNode(block_idx, out);
if (!real_node->Var()->Persistable() &&
vars_appear_multi_in_one_block_[block_idx].count(out->Name())) {
for (auto op_type :
vars_appear_multi_in_one_block_[block_idx].at(out->Name())) {
if (OpSupportPrecision(
op_type, backend_, mixed_precision_, black_list_)) {
ret = true;
VLOG(2) << out->Name()
<< " is multi precision op's out, so we skip convert to fp16";
break;
}
}
}
if (ret) break;
}
return ret;
}
void ConvertToMixedPrecisionPass::ProcessInputNode(
bool support_precision,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
framework::proto::VarType::Type to_type,
int block_idx) {
auto* real_node = GetRealNode(block_idx, in_node);
if (!NodeVarHasDtype(real_node)) return;
auto graph = graphes_[block_idx];
bool is_main_block = block_idx == 0;
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType();
auto prev_type = in_var_type;
bool is_in_multi_block = vars_in_multi_block_map_.count(in_var->Name());
if (!is_main_block && is_in_multi_block) {
in_var_type = vars_in_multi_block_map_.at(in_var->Name()).first;
}
if (support_precision) {
if (in_var->Persistable() &&
in_var_type == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) return;
in_var->SetDataType(to_type);
in_var_type = to_type;
VLOG(3) << " in_node name " << in_var->Name() << " from " << prev_type
<< " to " << to_type;
} else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
} else {
if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
&cast_map_);
VLOG(3) << " in_node name " << in_var->Name() << "(" << prev_type
<< ") to " << cast_map_[in_node]->Name() << "(" << to_type << ")";
}
}
}
void ConvertToMixedPrecisionPass::ProcessOutputNode(
int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type) {
auto* real_node = GetRealNode(block_idx, var_node);
if (!NodeVarHasDtype(real_node)) return;
auto* out_var = real_node->Var();
auto prev_type = out_var->GetDataType();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
}
VLOG(3) << " out_node name " << var_node->Name() << " from dtype "
<< prev_type << " to " << out_var->GetDataType();
}
// Just process special cases.
bool OutShouldNotConvert(ir::Node* var_node) {
bool ConvertToMixedPrecisionPass::OutShouldNotConvert(ir::Node* var_node) {
auto op_node = var_node->inputs[0];
auto* op_desc = op_node->Op();
@@ -262,28 +342,8 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return false;
}
void ProcessOutputNode(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
auto* real_node =
GetRealNode(graphes, block_idx, var_node, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) return;
auto* out_var = real_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
}
VLOG(3) << " out_node name " << var_node->Name() << " data_type "
<< out_var->GetDataType();
}
// Just process special cases for weights conversion.
bool WeightsShouldNotConvert(ir::Node* var_node) {
bool ConvertToMixedPrecisionPass::WeightsShouldNotConvert(ir::Node* var_node) {
auto op_nodes = var_node->outputs;
for (auto* op_node : op_nodes) {
auto* op_desc = op_node->Op();
@@ -331,72 +391,69 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
return false;
}
inline bool IsFloatVarType(framework::proto::VarType::Type type) {
inline bool ConvertToMixedPrecisionPass::IsFloatVarType(
framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 ||
type == framework::proto::VarType::BF16)
return true;
return false;
}
void ProcessInputNode(
bool support_precision,
std::vector<framework::ir::Graph*> graphes,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map,
framework::proto::VarType::Type to_type,
int block_idx,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
auto* real_node =
GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) return;
auto graph = graphes[block_idx];
bool is_main_block = block_idx == 0;
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType();
bool is_in_multi_block = vars_in_multi_block_map->count(in_var->Name());
if (!is_main_block && is_in_multi_block) {
in_var_type = vars_in_multi_block_map->at(in_var->Name()).first;
}
if (support_precision) {
if (in_var->Persistable() &&
in_var_type == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) return;
in_var->SetDataType(to_type);
in_var_type = to_type;
} else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
cast_map);
void ConvertToMixedPrecisionPass::LoadAndPrepare() {
program_desc_ =
inference::Load(&executor_, &scope_, model_file_, params_file_);
main_graph_ = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc_));
// Remove all control vars.
IrInferCleanGraphPass pass;
Argument arg;
arg.SetMainGraphNotOwned(main_graph_.get());
pass.Run(&arg);
vars_appear_multi_in_one_block_.resize(program_desc_->Size());
FindVarsInMultiBlock();
}
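// A var lives in multiple blocks when its name appears in the op argument
// lists of two different blocks; for each such var we record a dtype
// placeholder (FP32, updated later in ConvertTensorDtype) together with the
// id of the block where it first appeared.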
void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
for (size_t i = 0; i < program_desc_->Size(); ++i) {
for (auto op : program_desc_->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set[i].insert(in_names.begin(), in_names.end());
auto out_names = op->OutputArgumentNames();
if (op->HasAttr("sub_block") == false) {
for (auto& n : out_names) {
if (block_var_names_set[i].count(n)) {
vars_appear_multi_in_one_block_[i][n].push_back(op->Type());
}
}
}
block_var_names_set[i].insert(out_names.begin(), out_names.end());
}
} else {
if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) {
AddCastOp(graph,
in_node,
op_node,
in_var_type,
to_type,
suffix,
block_desc,
cast_map);
}
for (size_t i = 0; i < program_desc_->Size() - 1; ++i) {
for (size_t j = i + 1; j < program_desc_->Size(); ++j) {
std::set<std::string> vars_in_multi_block;
std::set_intersection(
block_var_names_set[i].begin(),
block_var_names_set[i].end(),
block_var_names_set[j].begin(),
block_var_names_set[j].end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map_.emplace(
name, std::make_pair(framework::proto::VarType::FP32, i));
}
}
}
VLOG(3) << " in_node name " << in_var->Name() << " data_type " << in_var_type;
}
void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
@@ -436,7 +493,6 @@ void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (!in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP64) {
@@ -446,158 +502,47 @@ void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
}
}
// Handle special ops which contain a dtype attribute, e.g., fill_constant,
// assign_value.
void HandleSpecialOps(framework::OpDesc* op_desc) {
if (op_desc->Type() == "fill_constant") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "assign_value") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "eye") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "fill_any_like") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "fill_constant_batch_size_like") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
}
}
// We modify the ops' input/output precision, and we need to fix the cast
// ops' in_dtype and out_dtype attributes.
void FixCastAttr(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type != "cast") continue;
auto input = op_node->inputs[0];
auto output = op_node->outputs[0];
op_node->Op()->SetAttr("in_dtype",
static_cast<int>(input->Var()->GetDataType()));
op_node->Op()->SetAttr("out_dtype",
static_cast<int>(output->Var()->GetDataType()));
}
}
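// Pipeline of the pass, as implemented below: load the program and build the
// graph, collect vars shared across blocks, then for every subgraph run
// ConvertAllFp64ToFp32 -> ConvertTensorDtype -> FixCastAttr, verify var desc
// consistency, patch ops whose input and output share a name, and finally
// save the mixed-precision model and params.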
void ConvertToMixedPrecisionPass::Run() {
LoadAndPrepare();
void FindVarsInMultiBlock(
framework::ProgramDesc* program_desc,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
std::vector<std::set<std::string>>* vars_appear_multi_in_one_block) {
std::vector<std::set<std::string>> block_var_names_set(program_desc->Size());
for (size_t i = 0; i < program_desc->Size(); ++i) {
for (auto op : program_desc->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set[i].insert(in_names.begin(), in_names.end());
auto out_names = op->OutputArgumentNames();
if (op->HasAttr("sub_block") == false) {
for (auto& n : out_names) {
if (block_var_names_set[i].count(n)) {
(*vars_appear_multi_in_one_block)[i].insert(n);
}
}
}
block_var_names_set[i].insert(out_names.begin(), out_names.end());
}
}
for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
auto graph = main_graph_->GetSubGraph(i);
graphes_.push_back(graph);
VLOG(2) << " -------- handle subgraph " << i << ", has "
<< graph->Nodes().size() << " nodes --------";
for (size_t i = 0; i < program_desc->Size() - 1; ++i) {
for (size_t j = i + 1; j < program_desc->Size(); ++j) {
std::set<std::string> vars_in_multi_block;
std::set_intersection(
block_var_names_set[i].begin(),
block_var_names_set[i].end(),
block_var_names_set[j].begin(),
block_var_names_set[j].end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
ConvertAllFp64ToFp32(graph);
ConvertTensorDtype(i);
FixCastAttr(graph);
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map->emplace(
name, std::make_pair(framework::proto::VarType::FP32, i));
}
}
}
}
// A trick for ops whose input and output share the same name.
PatchForStrangeOp();
bool OpInOutHasTensorArray(
std::vector<framework::ir::Graph*> graphes,
int block_idx,
framework::ir::Node* op_node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
CHECK_EQ(op_node->IsOp(), true);
for (auto in : op_node->inputs) {
auto* real_node =
GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) continue;
if (real_node->Var()->GetType() ==
framework::proto::VarType::LOD_TENSOR_ARRAY)
return true;
CHECK_EQ(ir::VarDescIsConsistency(*graph), true);
}
for (auto out : op_node->outputs) {
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) continue;
if (real_node->Var()->GetType() ==
framework::proto::VarType::LOD_TENSOR_ARRAY)
return true;
}
return false;
SaveMixedModel();
}
void ConvertTensorDtype(
framework::ProgramDesc* program_desc,
std::vector<framework::ir::Graph*> graphes,
const std::unordered_set<std::string>& blacklist,
bool keep_io_types,
phi::Backend backend,
phi::DataType tensor_dtype,
int block_idx,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
auto graph = graphes[block_idx];
void ConvertToMixedPrecisionPass::ConvertTensorDtype(int block_idx) {
auto graph = graphes_[block_idx];
framework::proto::VarType::Type to_type;
if (tensor_dtype == phi::DataType::FLOAT16) {
if (mixed_precision_ == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16;
} else if (tensor_dtype == phi::DataType::BFLOAT16) {
} else if (mixed_precision_ == phi::DataType::BFLOAT16) {
to_type = framework::proto::VarType::BF16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only "
"support fp16 and bf16.",
static_cast<int>(tensor_dtype)));
static_cast<int>(mixed_precision_)));
}
auto* block_desc =
framework::ir::TopologySortOperations(*graph)[0]->Op()->Block();
auto op_nodes = framework::ir::TopologySortOperations(*graph);
auto* block_desc = op_nodes[0]->Op()->Block();
int num_low_precision = 0;
int suffix = 0;
std::vector<framework::ir::Node*> output_nodes;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map;
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
@@ -606,7 +551,7 @@ void ConvertTensorDtype(
// 1. set input dtype.
if (op_type == "feed") {
auto feed_var = op_node->outputs[0]->Var();
if (!keep_io_types &&
if (!keep_io_types_ &&
feed_var->GetDataType() == framework::proto::VarType::FP32) {
feed_var->SetDataType(to_type);
}
@@ -623,16 +568,14 @@ void ConvertTensorDtype(
// same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) {
auto* real_node =
GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
auto* real_node = GetRealNode(block_idx, in);
if (NodeVarHasDtype(real_node)) {
in_name_to_node[in->Name()] = in;
}
}
for (auto out : op_node->outputs) {
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
auto* real_node = GetRealNode(block_idx, out);
if (NodeVarHasDtype(real_node)) {
if (in_name_to_node.count(out->Name()))
real_node->Var()->SetDataType(
@@ -643,23 +586,6 @@ void ConvertTensorDtype(
continue;
}
// A strange case found in multi-block models.
else if (op_type == "assign" && // NOLINT
op_node->inputs[0]->Name() == op_node->outputs[0]->Name()) {
VLOG(2) << " in out are same, continue";
continue;
}
// Handle tensor array.
else if (OpInOutHasTensorArray( // NOLINT
graphes,
block_idx,
op_node,
vars_in_multi_block_map)) {
VLOG(2) << " in or out has tensor array, continue";
continue;
}
// 2. the op supports fp16/bf16 and is not in the blacklist.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
@@ -667,22 +593,16 @@ void ConvertTensorDtype(
//
// If a var (an op's out var) appears multiple times in a block, we should
// not convert it to fp16.
else if (blacklist.count(op_type) == 0 && // NOLINT
!VarIsMultiOpsOut(graphes,
block_idx,
op_node,
vars_in_multi_block_map,
vars_appear_multi_in_one_block)) {
else if (black_list_.count(op_type) == 0 && // NOLINT
!VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
bool support_precision =
OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
VLOG(2) << " support low precision " << support_precision;
OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
// If the op has no float input, we will not choose the low precision kernel.
{
bool has_float_input{false};
for (auto in_node : op_node->inputs) {
auto* real_node =
GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
auto* real_node = GetRealNode(block_idx, in_node);
if (real_node->Var()->GetDataType() == proto::VarType::FP16 ||
real_node->Var()->GetDataType() == proto::VarType::FP32 ||
real_node->Var()->GetDataType() == proto::VarType::FP64 ||
@@ -696,42 +616,47 @@ void ConvertTensorDtype(
VLOG(2) << " op doesn't has float input, just skip.";
}
}
VLOG(2) << " support low precision " << support_precision;
if (support_precision) {
HandleSpecialOps(op_node->Op());
VLOG(2) << " process input nodes:";
++num_low_precision;
auto inputs = op_node->inputs;
// Just for Paddle's terrible case: an op's input and output have the same
// name.
std::unordered_map<std::string, std::string> names_map;
for (auto out_node : op_node->outputs) {
for (auto in_node : op_node->inputs) {
if (out_node->Name() == in_node->Name()) {
names_map[out_node->Name()] = in_node->Name();
}
}
}
// Process inputs.
for (auto* in_node : inputs) {
ProcessInputNode(true,
graphes,
in_node,
op_node,
&suffix,
block_desc,
&cast_map,
to_type,
block_idx,
vars_in_multi_block_map);
ProcessInputNode(
true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
names_map[in_node->Name()] = cast_map_[in_node]->Name();
}
}
VLOG(2) << " process output nodes:";
// Process outputs.
for (auto* out_node : op_node->outputs) {
ProcessOutputNode(
graphes, block_idx, out_node, to_type, vars_in_multi_block_map);
ProcessOutputNode(block_idx, out_node, to_type);
}
} else {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
ProcessInputNode(false,
graphes,
in_node,
op_node,
&suffix,
&suffix_,
block_desc,
&cast_map,
framework::proto::VarType::FP32,
block_idx,
vars_in_multi_block_map);
block_idx);
}
}
}
@@ -739,9 +664,9 @@ void ConvertTensorDtype(
// 3. the op does not support fp16/bf16, or it is in the blacklist.
// - add cast op if the input dtype is not fp32.
else { // NOLINT
VLOG(3) << "not to run fp16 op_type: " << op_type;
auto ins = op_node->inputs;
for (auto* in_node : ins) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) {
AddCastOp(graph,
@@ -749,9 +674,12 @@ void ConvertTensorDtype(
op_node,
to_type,
framework::proto::VarType::FP32,
&suffix,
&suffix_,
block_desc,
&cast_map);
&cast_map_);
VLOG(3) << "-- " << in_node->Name() << "(" << to_type << ") to "
<< cast_map_[in_node]->Name() << "("
<< framework::proto::VarType::FP32 << ")";
}
}
}
@@ -760,40 +688,45 @@ void ConvertTensorDtype(
// 4. if the output var's dtype is not compatible with the expected output
// dtype, just insert a cast.
for (auto* node : output_nodes) {
if (node->IsCtrlVar()) continue;
ir::Node* fetch_op{nullptr};
for (auto* op_node : node->outputs) {
if (op_node->IsOp() && op_node->Op()->Type() == "fetch") {
fetch_op = op_node;
}
}
CHECK_NOTNULL(fetch_op);
auto var = node->Var();
if (keep_io_types && var->GetDataType() == to_type) {
if (keep_io_types_ && var->GetDataType() == to_type) {
// fp16/bf16 -> fp32.
AddCastOp(graph,
node,
node->outputs[0],
fetch_op,
to_type,
framework::proto::VarType::FP32,
&suffix,
&suffix_,
block_desc,
&cast_map);
} else if (!keep_io_types &&
&cast_map_);
} else if (!keep_io_types_ &&
var->GetDataType() == framework::proto::VarType::FP32) {
// fp32 -> fp16/bf16
AddCastOp(graph,
node,
node->outputs[0],
fetch_op,
framework::proto::VarType::FP32,
to_type,
&suffix,
&suffix_,
block_desc,
&cast_map);
&cast_map_);
}
}
for (auto node : graph->Nodes()) {
auto* real_node =
GetRealNode(graphes, block_idx, node, vars_in_multi_block_map);
auto* real_node = GetRealNode(block_idx, node);
if (!NodeVarHasDtype(real_node)) continue;
if (vars_in_multi_block_map->count(real_node->Name()) &&
vars_in_multi_block_map->at(real_node->Name()).second == block_idx) {
vars_in_multi_block_map->at(real_node->Name()).first =
if (vars_in_multi_block_map_.count(real_node->Name()) &&
vars_in_multi_block_map_.at(real_node->Name()).second == block_idx) {
vars_in_multi_block_map_.at(real_node->Name()).first =
real_node->Var()->GetDataType();
}
}
@@ -802,24 +735,118 @@ void ConvertTensorDtype(
LOG(INFO) << "--- detected " << num_low_precision
<< " low precision ops in " << block_idx << " subgraph";
}
} // namespace
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support_precision = false;
if (blacklist.count(op_type) == 0) {
if (backend == phi::Backend::GPU)
support_precision = GpuKernelSupportPrecision(op_type, precision);
else
support_precision =
PhiKernelSupportPrecision(phi_op_type, backend, precision);
// We modify the ops' input/output precision, and we need to fix the cast
// ops' in_dtype and out_dtype attributes.
void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type != "cast") continue;
auto input = op_node->inputs[0];
auto output = op_node->outputs[0];
op_node->Op()->SetAttr("in_dtype",
static_cast<int>(input->Var()->GetDataType()));
op_node->Op()->SetAttr("out_dtype",
static_cast<int>(output->Var()->GetDataType()));
}
return support_precision;
}
void ConvertToMixedPrecisionPass::SaveMixedModel() {
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);
paddle::CPUPlace place;
auto parameters = scope_.LocalVarNames();
std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : main_graph_->Nodes()) {
if (!(node->IsVar())) continue;
if (NodeVarHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
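// CONVERT_TENSOR_DTYPE narrows the fp32 weight element-wise into the target
// dtype, then copies the converted tensor back over the original one.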
#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \
mixed_tensor.set_type(DTYPE); \
auto* mixed_data = mixed_tensor.mutable_data<dtype>(platform::CPUPlace()); \
for (int i = 0; i < t->numel(); i++) { \
mixed_data[i] = static_cast<dtype>(data[i]); \
} \
t->clear(); \
paddle::framework::TensorCopySync(mixed_tensor, place, t)
for (const auto& param_name : parameters) {
auto* var = scope_.FindLocalVar(param_name);
if (var->IsType<phi::DenseTensor>()) {
auto* t = var->GetMutable<phi::DenseTensor>();
if (t->dtype() != phi::DataType::FLOAT32) continue;
phi::DenseTensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision_ == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16,
phi::dtype::float16);
} else if (mixed_precision_ == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16,
phi::dtype::bfloat16);
}
}
}
#undef CONVERT_TENSOR_DTYPE
auto SerializeParams = [&]() -> std::string {
std::ostringstream os;
phi::CPUContext ctx;
for (const auto& param : parameters) {
VLOG(3) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL(
scope_.FindVar(param),
platform::errors::NotFound(
"Block should already have a '%s' variable", param));
auto* tensor = scope_.FindVar(param)->GetMutable<framework::LoDTensor>();
framework::SerializeToStream(os, *tensor, ctx);
}
return os.str();
};
auto StrToBinary = [](const std::string& path, const std::string& str) {
std::ofstream file(path.c_str(), std::ios::binary);
file.write(str.c_str(), str.size());
file.close();
};
StrToBinary(mixed_model_file_,
mixed_program_desc.Proto()->SerializeAsString());
StrToBinary(mixed_params_file_, SerializeParams());
}
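// fused_multi_transformer reuses its CacheKV input buffers as its outputs,
// so renaming each CacheKVOut argument back to the matching CacheKV input
// keeps the graph's var names consistent (this rationale is an assumption;
// the renaming itself is exactly what the code below does).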
void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
for (auto* graph : graphes_) {
for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
if (op_node->Name() == "fused_multi_transformer") {
auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
}
}
}
}
}
} // namespace
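// AddCastOp inserts `node -> cast -> cast_output -> next_op` and caches the
// cast output node in *map, so later consumers of the same node reuse one
// cast instead of creating duplicates.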
void AddCastOp(
framework::ir::Graph* graph,
framework::ir::Node* node,
@@ -865,11 +892,27 @@ void AddCastOp(
IR_NODE_LINK_TO(cast_op_node, cast_output_node);
(*map)[node] = cast_output_node;
}
next_op->Op()->RenameInput(node->Name(), map->at(node)->Name());
next_op->Op()->Rename(node->Name(), map->at(node)->Name());
IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
IR_NODE_LINK_TO(map->at(node), next_op);
}
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support_precision = false;
if (blacklist.count(op_type) == 0) {
if (backend == phi::Backend::GPU)
support_precision = GpuKernelSupportPrecision(op_type, precision);
else
support_precision =
PhiKernelSupportPrecision(phi_op_type, backend, precision);
}
return support_precision;
}
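// Example (hypothetical op and an empty blacklist): query whether conv2d has
// a GPU fp16 kernel before deciding to convert it:
//   bool ok = OpSupportPrecision(
//       "conv2d", phi::Backend::GPU, phi::DataType::FLOAT16, {});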
void ConvertToMixedPrecision(const std::string& model_file,
const std::string& params_file,
const std::string& mixed_model_file,
@@ -878,53 +921,15 @@ void ConvertToMixedPrecision(const std::string& model_file,
phi::Backend backend,
bool keep_io_types,
std::unordered_set<std::string> black_list) {
paddle::CPUPlace place;
framework::Executor executor(place);
framework::Scope scope;
auto program_desc =
inference::Load(&executor, &scope, model_file, params_file);
auto main_graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc));
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>
vars_in_multi_block_map;
std::vector<std::set<std::string>> vars_appear_multi_in_one_block(
program_desc->Size());
FindVarsInMultiBlock(program_desc.get(),
&vars_in_multi_block_map,
&vars_appear_multi_in_one_block);
std::vector<framework::ir::Graph*> graphes;
for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) {
auto graph = main_graph->GetSubGraph(i);
graphes.push_back(graph);
VLOG(2) << " -------- handle subgraph " << i << ", has "
<< graph->Nodes().size() << " nodes --------";
ConvertAllFp64ToFp32(graph);
ConvertTensorDtype(program_desc.get(),
graphes,
black_list,
keep_io_types,
backend,
mixed_precision,
i,
&vars_in_multi_block_map,
vars_appear_multi_in_one_block);
FixCastAttr(graph);
}
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph, &mixed_program_desc);
SaveMixedModel(main_graph.get(),
&scope,
&mixed_program_desc,
mixed_model_file,
mixed_params_file,
mixed_precision,
vars_in_multi_block_map);
ConvertToMixedPrecisionPass pass(model_file,
params_file,
mixed_model_file,
mixed_params_file,
mixed_precision,
backend,
keep_io_types,
black_list);
pass.Run();
}
} // namespace analysis
......
@@ -30,7 +30,7 @@ namespace paddle {
namespace inference {
namespace analysis {
bool OpSupportPrecision(const std::string& phi_op_type,
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist);
......
@@ -140,39 +140,12 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
auto var_data_type = var_node->Var()->GetDataType();
VLOG(5) << "var_name is " << var_name << ", data type is "
<< var_data_type;
if (var_data_type == paddle::framework::proto::VarType::FP16 &&
t->dtype() != paddle::experimental::DataType::FLOAT16) {
framework::Tensor half_tensor;
half_tensor.set_type(paddle::experimental::DataType::FLOAT16);
half_tensor.Resize(t->dims());
auto *half_data =
half_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
auto *data = t->mutable_data<float16>(platform::CPUPlace());
half_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(half_tensor, place, t);
} else if (var_data_type == paddle::framework::proto::VarType::BF16) {
framework::Tensor bf16_tensor;
bf16_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
bf16_tensor.Resize(t->dims());
auto *bf16_data = bf16_tensor.mutable_data<platform::bfloat16>(
platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
auto *data = t->mutable_data<bfloat16>(platform::CPUPlace());
bf16_data[i] = static_cast<platform::bfloat16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(bf16_tensor, place, t);
} else {
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t->dims());
paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
t->clear();
paddle::framework::TensorCopySync(temp_tensor, place, t);
}
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t->dims());
paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
t->clear();
paddle::framework::TensorCopySync(temp_tensor, place, t);
}
}
}
......