未验证 提交 b4a4eef2 编写于 作者: W Wilber 提交者: GitHub

convert support multi block. (#44866)

* convert support multi block.

* update
上级 f9e7fe66
......@@ -14,7 +14,10 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <algorithm>
#include <iterator>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/block_desc.h"
......@@ -39,7 +42,106 @@ namespace analysis {
namespace {
bool IsKernelSupportPrecision(
// Serializes the named variables of `scope` into a single binary string.
//
// The variables are written back to back, in the order given by `params`,
// each through framework::SerializeToStream with a CPU context.
//
// @param scope   Scope the variables are looked up in (via FindVar).
// @param params  Names of the variables to serialize, in output order.
// @return        The concatenated serialized bytes.
// Raises a NotFound enforce error if any name cannot be found in `scope`.
inline std::string SerializeParams(framework::Scope* scope,
                                   const std::vector<std::string>& params) {
  std::ostringstream buffer;
  phi::CPUContext cpu_ctx;
  for (const auto& param : params) {
    VLOG(3) << "Serialize param: " << param;
    // NOTE(review): the lookup expression is passed to the enforce macro
    // verbatim on purpose; the macro presumably echoes its argument text in
    // the error hint, so caching the lookup in a local would alter the
    // message — confirm against the macro definition before simplifying.
    PADDLE_ENFORCE_NOT_NULL(
        scope->FindVar(param),
        platform::errors::NotFound("Block should already have a '%s' variable",
                                   param));
    auto& tensor = *scope->FindVar(param)->GetMutable<framework::LoDTensor>();
    framework::SerializeToStream(buffer, tensor, cpu_ctx);
  }
  return buffer.str();
}
// Writes the bytes of `str` to the file at `path`, opened in binary mode.
//
// The file is created or truncated by the default std::ofstream open mode.
// Stream errors (e.g. the file cannot be opened) are not reported to the
// caller.
//
// @param path  Destination file path.
// @param str   Raw bytes to write; embedded '\0' characters are preserved
//              because the explicit size is used.
inline void StrToBinary(const std::string& path, const std::string& str) {
  std::ofstream out(path, std::ios::binary);
  out.write(str.data(), static_cast<std::streamsize>(str.size()));
  // `out` is flushed and closed by its destructor when it leaves scope.
}
// Returns true when `node` is a non-control variable node whose variable type
// is one of SELECTED_ROWS, LOD_TENSOR, LOD_TENSOR_ARRAY, STRINGS or VOCAB.
//
// Callers in this file use it as a guard before reading or writing the
// variable's data type; presumably other variable types do not carry one
// (TODO confirm against VarDesc::GetDataType).
//
// @param node  Graph node to inspect.
// @return      false for control-dependency vars, non-variable nodes and any
//              variable type not listed above.
inline bool NodeVarHasDtype(framework::ir::Node* node) {
  // Var() is only touched once the node is known to be a real variable.
  if (node->IsCtrlVar() || !node->IsVar()) return false;

  using VarType = paddle::framework::proto::VarType;
  switch (node->Var()->GetType()) {
    case VarType::SELECTED_ROWS:
    case VarType::LOD_TENSOR:
    case VarType::LOD_TENSOR_ARRAY:
    case VarType::STRINGS:
    case VarType::VOCAB:
      return true;
    default:
      return false;
  }
}
void SaveMixedModel(framework::ir::Graph* graph,
framework::Scope* scope,
framework::ProgramDesc* mixed_program_desc,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision) {
paddle::CPUPlace place;
auto parameters = scope->LocalVarNames();
std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : graph->Nodes()) {
if (!(node->IsVar() && !node->IsCtrlVar())) continue;
if (NodeVarHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
for (const auto& param_name : parameters) {
auto* var = scope->FindLocalVar(param_name);
if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto* t = var->GetMutable<framework::LoDTensor>();
framework::Tensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
} else if (mixed_precision == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<bfloat16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<bfloat16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
}
}
}
StrToBinary(mixed_model_file,
mixed_program_desc->Proto()->SerializeAsString());
StrToBinary(mixed_params_file, SerializeParams(scope, parameters));
}
bool PhiKernelSupportPrecision(
const std::string& op_type,
phi::Backend backend,
phi::DataType data_type,
......@@ -56,10 +158,23 @@ bool GpuKernelSupportPrecision(
const std::string& op_type,
phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
bool res =
IsKernelSupportPrecision(op_type, phi::Backend::GPU, data_type, layout);
res |= IsKernelSupportPrecision(
op_type, phi::Backend::GPUDNN, data_type, layout);
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool res = PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPU, data_type, layout);
res |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, data_type, layout);
if (!res) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
res = true;
}
}
}
}
return res;
}
......@@ -90,30 +205,16 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return false;
}
// Get the names of weights that appear in multiple blocks (block 0 and block n).
std::unordered_set<std::string> GetMultiBlockPersistableNames(
framework::ProgramDesc* program_desc) {
std::unordered_set<std::string> special_weights;
size_t block_size = program_desc->Size();
std::unordered_set<std::string> block_0_weights;
for (auto var : program_desc->Block(0).AllVars()) {
if (var->Persistable()) block_0_weights.insert(var->Name());
}
for (size_t i = 1; i < block_size; ++i) {
// std::cout << program_desc->MutableBlock(i)->Proto()->DebugString() <<
// std::endl;;
auto all_ops = program_desc->Block(i).AllOps();
for (auto op : all_ops) {
for (auto name : op->InputArgumentNames()) {
if (block_0_weights.count(name)) special_weights.insert(name);
}
}
void ProcessOutputNode(ir::Node* var_node,
framework::proto::VarType::Type to_type) {
if (!NodeVarHasDtype(var_node)) return;
auto* out_var = var_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
}
return special_weights;
VLOG(3) << " out_node name " << var_node->Name() << " data_type "
<< out_var->GetDataType();
}
// Just process special cases for weights conversion.
......@@ -143,21 +244,8 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
}
}
// If cur_op's next is condition_flow op, then cur op should be fp32. Note, we
// now only convert to mixed in block 0.
for (auto* op_node : op_nodes) {
for (auto var : op_node->outputs) {
for (auto next_op : var->outputs) {
if (next_op->Op()->HasAttr("sub_block")) {
return true;
}
}
}
}
return false;
}
inline bool IsFloatVarType(framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 ||
......@@ -165,6 +253,56 @@ inline bool IsFloatVarType(framework::proto::VarType::Type type) {
return true;
return false;
}
// Adjusts one input variable of `op_node` for the precision the op will run
// in.
//
// Nodes rejected by NodeVarHasDtype are ignored. Otherwise:
//  - a persistable input whose effective type is FP32 is retagged to
//    `to_type`, but only when `support_precision` is true and
//    WeightsShouldNotConvert does not veto it (a veto returns immediately,
//    skipping the trailing log line);
//  - a non-persistable floating-point input whose effective type differs
//    from `to_type` gets a cast inserted through AddCastOp, regardless of
//    `support_precision`.
//
// @param support_precision        Whether the consuming op runs in low
//                                 precision.
// @param graph, op_node, suffix, block_desc, cast_map
//                                 Forwarded unchanged to AddCastOp.
// @param in_node                  Input variable node being processed.
// @param to_type                  Data type the op expects for its inputs.
// @param is_main_block            When false, a type recorded in
//                                 `vars_in_multi_block_map` for this variable
//                                 name overrides the type in the var desc.
// @param vars_in_multi_block_map  Variable name -> recorded data type; only
//                                 read here.
void ProcessInputNode(
    bool support_precision,
    framework::ir::Graph* graph,
    ir::Node* in_node,
    ir::Node* op_node,
    int* suffix,
    framework::BlockDesc* block_desc,
    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map,
    framework::proto::VarType::Type to_type,
    bool is_main_block,
    std::unordered_map<std::string, framework::proto::VarType::Type>*
        vars_in_multi_block_map) {
  if (!NodeVarHasDtype(in_node)) return;

  auto* var_desc = in_node->Var();
  auto effective_type = var_desc->GetDataType();
  if (!is_main_block) {
    const auto recorded = vars_in_multi_block_map->find(var_desc->Name());
    if (recorded != vars_in_multi_block_map->end()) {
      effective_type = recorded->second;
    }
  }

  if (var_desc->Persistable()) {
    // Weights are never cast at runtime; they are only retagged, and only
    // when the consuming op actually runs in low precision.
    if (support_precision &&
        effective_type == framework::proto::VarType::FP32) {
      if (WeightsShouldNotConvert(in_node)) return;
      var_desc->SetDataType(to_type);
    }
  } else if (IsFloatVarType(effective_type) && effective_type != to_type) {
    AddCastOp(graph,
              in_node,
              op_node,
              effective_type,
              to_type,
              suffix,
              block_desc,
              cast_map);
  }

  VLOG(3) << " in_node name " << var_desc->Name() << " data_type "
          << var_desc->GetDataType();
}
void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph);
......@@ -239,6 +377,11 @@ void HandleSpecialOps(framework::OpDesc* op_desc) {
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "fill_constant_batch_size_like") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
}
}
......@@ -260,26 +403,47 @@ void FixCastAttr(framework::ir::Graph* graph) {
}
}
// If op's output var is condition flow op's input, then the op must be fp32
// precision.
bool NextOpIncludesConditionFlowOp(framework::ir::Node* cur_op_node) {
auto cur_op_outs = cur_op_node->outputs;
for (auto out_var : cur_op_outs) {
for (auto next_op_node : out_var->outputs) {
if (next_op_node->Op()->HasAttr("sub_block")) {
return true;
void FindVarsInMultiBlock(
framework::ProgramDesc* program_desc,
std::unordered_map<std::string, framework::proto::VarType::Type>*
vars_in_multi_block_map) {
std::set<std::string> vars_in_multi_block;
std::set<std::string> main_block_var_names_set;
for (auto op : program_desc->Block(0).AllOps()) {
auto in_names = op->InputArgumentNames();
main_block_var_names_set.insert(in_names.begin(), in_names.end());
}
for (size_t i = 1; i < program_desc->Size(); ++i) {
std::set<std::string> block_var_names_set;
for (auto op : program_desc->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set.insert(in_names.begin(), in_names.end());
}
std::set_intersection(
main_block_var_names_set.begin(),
main_block_var_names_set.end(),
block_var_names_set.begin(),
block_var_names_set.end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
}
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map->emplace(name, framework::proto::VarType::FP32);
}
return false;
}
void ConvertTensorDtype(framework::ProgramDesc* program_desc,
void ConvertTensorDtype(
framework::ProgramDesc* program_desc,
framework::ir::Graph* graph,
const std::unordered_set<std::string>& blacklist,
bool keep_io_types,
phi::Backend backend,
phi::DataType tensor_dtype) {
phi::DataType tensor_dtype,
bool is_main_block,
std::unordered_map<std::string, framework::proto::VarType::Type>*
vars_in_multi_block_map) {
framework::proto::VarType::Type to_type;
if (tensor_dtype == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16;
......@@ -287,25 +451,27 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
to_type = framework::proto::VarType::BF16;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only support "
"mixed_precision currently not supported dtype %d, we now only "
"support "
"fp16 and bf16.",
static_cast<int>(tensor_dtype)));
}
auto weight_name_in_multi_block = GetMultiBlockPersistableNames(program_desc);
auto* block_desc =
framework::ir::TopologySortOperations(*graph)[0]->Op()->Block();
int num_low_precision = 0;
int suffix = 0;
framework::BlockDesc* block_desc{nullptr};
std::vector<framework::ir::Node*> output_nodes;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map;
auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
auto phi_op_type = phi::TransToPhiKernelName(op_type);
VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
<< phi::TransToPhiKernelName(op_type);
// 1. set input dtype.
if (op_type == "feed") {
block_desc = op_node->Op()->Block();
auto feed_var = op_node->outputs[0]->Var();
if (!keep_io_types &&
feed_var->GetDataType() == framework::proto::VarType::FP32) {
......@@ -319,71 +485,73 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
continue;
}
else if (op_node->Op()->HasAttr("sub_block")) { // NOLINT
// sub_block op's output dtype should be same as input dtype, if have the
// same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) {
if (NodeVarHasDtype(in)) {
in_name_to_node[in->Name()] = in;
}
}
for (auto out : op_node->outputs) {
if (NodeVarHasDtype(out)) {
if (in_name_to_node.count(out->Name()))
out->Var()->SetDataType(
in_name_to_node[out->Name()]->Var()->GetDataType());
}
}
continue;
}
// 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
else if (blacklist.count(phi_op_type) == 0 && // NOLINT
!NextOpIncludesConditionFlowOp(op_node)) {
else if (blacklist.count(op_type) == 0) { // NOLINT
bool support_precision =
OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist);
VLOG(2) << "op_type " << op_type << ", phi_op_type " << phi_op_type
<< " support low precision " << support_precision << ", "
OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
VLOG(2) << "op_type " << op_type << ", phi_op_type "
<< phi::TransToPhiKernelName(op_type) << " support low precision "
<< support_precision << ", "
<< reinterpret_cast<void*>(op_node->Op()->Block());
for (auto in_node : op_node->inputs) {
if (weight_name_in_multi_block.count(in_node->Name()))
support_precision = false;
}
if (support_precision) {
HandleSpecialOps(op_node->Op());
++num_low_precision;
auto inputs = op_node->inputs;
// Process inputs.
for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) continue;
in_var->SetDataType(to_type);
} else if (!in_var->Persistable() &&
IsFloatVarType(in_var->GetDataType()) &&
in_var->GetDataType() != to_type) {
AddCastOp(graph,
ProcessInputNode(true,
graph,
in_node,
op_node,
in_var->GetDataType(),
to_type,
&suffix,
block_desc,
&cast_map);
}
&cast_map,
to_type,
is_main_block,
vars_in_multi_block_map);
}
// Process outputs.
for (auto* out_node : op_node->outputs) {
if (out_node->IsCtrlVar()) continue;
auto* out_var = out_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(out_node)) continue;
out_var->SetDataType(to_type);
}
ProcessOutputNode(out_node, to_type);
}
} else {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (!in_var->Persistable() && IsFloatVarType(in_var->GetDataType()) &&
in_var->GetDataType() != framework::proto::VarType::FP32) {
AddCastOp(graph,
ProcessInputNode(false,
graph,
in_node,
op_node,
in_var->GetDataType(),
framework::proto::VarType::FP32,
&suffix,
block_desc,
&cast_map);
}
&cast_map,
framework::proto::VarType::FP32,
is_main_block,
vars_in_multi_block_map);
}
}
}
......@@ -409,8 +577,8 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
}
}
// 4. if output_op's dtype is not compatible to output dtype, then just insert
// cast.
// 4. if output_op's dtype is not compatible to output dtype, then just
// insert cast.
for (auto* node : output_nodes) {
if (node->IsCtrlVar()) continue;
auto var = node->Var();
......@@ -438,22 +606,31 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
}
}
if (is_main_block) {
for (auto node : graph->Nodes()) {
if (vars_in_multi_block_map->count(node->Name())) {
vars_in_multi_block_map->at(node->Name()) = node->Var()->GetDataType();
}
}
}
if (num_low_precision)
LOG(INFO) << "--- detected " << num_low_precision << " low precision ops";
}
} // namespace
bool OpSupportPrecision(const std::string& phi_op_type,
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& blacklist) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support_precision = false;
if (blacklist.count(phi_op_type) == 0) {
if (blacklist.count(op_type) == 0) {
if (backend == phi::Backend::GPU)
support_precision = GpuKernelSupportPrecision(phi_op_type, precision);
support_precision = GpuKernelSupportPrecision(op_type, precision);
else
support_precision =
IsKernelSupportPrecision(phi_op_type, backend, precision);
PhiKernelSupportPrecision(phi_op_type, backend, precision);
}
return support_precision;
}
......@@ -521,102 +698,41 @@ void ConvertToMixedPrecision(const std::string& model_file,
framework::Scope scope;
auto program_desc =
inference::Load(&executor, &scope, model_file, params_file);
auto graph = std::unique_ptr<framework::ir::Graph>(
auto main_graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc));
ConvertAllFp64ToFp32(graph.get());
std::unordered_map<std::string, framework::proto::VarType::Type>
vars_in_multi_block_map;
FindVarsInMultiBlock(program_desc.get(), &vars_in_multi_block_map);
for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) {
auto graph = main_graph->GetSubGraph(i);
VLOG(2) << " -------- handle subgraph " << i << ", has "
<< graph->Nodes().size() << " nodes";
program_desc->Block(i).LocalVarNames();
ConvertAllFp64ToFp32(graph);
ConvertTensorDtype(program_desc.get(),
graph.get(),
graph,
black_list,
keep_io_types,
backend,
mixed_precision);
FixCastAttr(graph.get());
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*graph, &mixed_program_desc);
auto parameters = scope.LocalVarNames();
std::sort(parameters.begin(), parameters.end());
auto serialize_params =
[](framework::Scope* scope,
const std::vector<std::string>& params) -> std::string {
std::ostringstream os;
phi::CPUContext ctx;
for (const auto& param : params) {
VLOG(3) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL(
scope->FindVar(param),
platform::errors::NotFound(
"Block should already have a '%s' variable", param));
auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
framework::SerializeToStream(os, *tensor, ctx);
mixed_precision,
i == 0,
&vars_in_multi_block_map);
FixCastAttr(graph);
}
return os.str();
};
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : graph->Nodes()) {
if (!(node->IsVar() && !node->IsCtrlVar())) continue;
if (node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
for (const auto& param_name : parameters) {
auto* var = scope.FindLocalVar(param_name);
if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto* t = var->GetMutable<framework::LoDTensor>();
framework::Tensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
} else if (mixed_precision == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<bfloat16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<bfloat16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
}
}
}
framework::ProgramDesc mixed_program_desc;
framework::ir::GraphToProgram(*main_graph, &mixed_program_desc);
auto StrToBinary = [](const std::string& path, const std::string& str) {
std::ofstream file(path.c_str(), std::ios::binary);
file.write(str.c_str(), str.size());
file.close();
};
StrToBinary(mixed_model_file,
mixed_program_desc.Proto()->SerializeAsString());
StrToBinary(mixed_params_file, serialize_params(&scope, parameters));
SaveMixedModel(main_graph.get(),
&scope,
&mixed_program_desc,
mixed_model_file,
mixed_params_file,
mixed_precision);
}
} // namespace analysis
......
......@@ -410,6 +410,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
pass_builder_->DeletePass(ps);
}
}
for (auto &delete_pass : other.pass_builder()->GetAllDeletedPasses()) {
pass_builder_->DeletePass(delete_pass);
}
}
void AnalysisConfig::EnableCUDNN() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册