未验证 提交 b4a4eef2 编写于 作者: W Wilber 提交者: GitHub

convert support multi block. (#44866)

* convert support multi block.

* update
上级 f9e7fe66
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include <algorithm>
#include <iterator>
#include <string> #include <string>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
...@@ -39,7 +42,106 @@ namespace analysis { ...@@ -39,7 +42,106 @@ namespace analysis {
namespace { namespace {
bool IsKernelSupportPrecision( inline std::string SerializeParams(framework::Scope* scope,
const std::vector<std::string>& params) {
std::ostringstream os;
phi::CPUContext ctx;
for (const auto& param : params) {
VLOG(3) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL(
scope->FindVar(param),
platform::errors::NotFound("Block should already have a '%s' variable",
param));
auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
framework::SerializeToStream(os, *tensor, ctx);
}
return os.str();
}
// Writes the bytes of `str` verbatim into the file at `path`, opened in
// binary mode. The stream state is never inspected, so a failed open or
// write is not reported to the caller.
inline void StrToBinary(const std::string& path, const std::string& str) {
  std::ofstream out(path.c_str(), std::ios::binary);
  out.write(str.data(), str.size());
  out.close();
}
// Returns true only when `node` is a (non-control) variable node whose
// VarDesc type is one of SELECTED_ROWS, LOD_TENSOR, LOD_TENSOR_ARRAY,
// STRINGS or VOCAB. Control vars, op nodes and every other var type
// yield false.
inline bool NodeVarHasDtype(framework::ir::Node* node) {
  // Var() is only queried once we know this is a real variable node.
  if (node->IsCtrlVar() || !node->IsVar()) return false;

  using VarType = paddle::framework::proto::VarType;
  switch (node->Var()->GetType()) {
    case VarType::SELECTED_ROWS:
    case VarType::LOD_TENSOR:
    case VarType::LOD_TENSOR_ARRAY:
    case VarType::STRINGS:
    case VarType::VOCAB:
      return true;
    default:
      return false;
  }
}
void SaveMixedModel(framework::ir::Graph* graph,
framework::Scope* scope,
framework::ProgramDesc* mixed_program_desc,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision) {
paddle::CPUPlace place;
auto parameters = scope->LocalVarNames();
std::sort(parameters.begin(), parameters.end());
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : graph->Nodes()) {
if (!(node->IsVar() && !node->IsCtrlVar())) continue;
if (NodeVarHasDtype(node)) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
for (const auto& param_name : parameters) {
auto* var = scope->FindLocalVar(param_name);
if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto* t = var->GetMutable<framework::LoDTensor>();
framework::Tensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
} else if (mixed_precision == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<bfloat16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<bfloat16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
}
}
}
StrToBinary(mixed_model_file,
mixed_program_desc->Proto()->SerializeAsString());
StrToBinary(mixed_params_file, SerializeParams(scope, parameters));
}
bool PhiKernelSupportPrecision(
const std::string& op_type, const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType data_type, phi::DataType data_type,
...@@ -56,10 +158,23 @@ bool GpuKernelSupportPrecision( ...@@ -56,10 +158,23 @@ bool GpuKernelSupportPrecision(
const std::string& op_type, const std::string& op_type,
phi::DataType data_type, phi::DataType data_type,
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) { phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT) {
bool res = auto phi_op_type = phi::TransToPhiKernelName(op_type);
IsKernelSupportPrecision(op_type, phi::Backend::GPU, data_type, layout); bool res = PhiKernelSupportPrecision(
res |= IsKernelSupportPrecision( phi_op_type, phi::Backend::GPU, data_type, layout);
op_type, phi::Backend::GPUDNN, data_type, layout); res |= PhiKernelSupportPrecision(
phi_op_type, phi::Backend::GPUDNN, data_type, layout);
if (!res) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
res = true;
}
}
}
}
return res; return res;
} }
...@@ -90,30 +205,16 @@ bool OutShouldNotConvert(ir::Node* var_node) { ...@@ -90,30 +205,16 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return false; return false;
} }
void ProcessOutputNode(ir::Node* var_node,
// Get weight names which appear in multiple block (block 0 and block n). framework::proto::VarType::Type to_type) {
std::unordered_set<std::string> GetMultiBlockPersistableNames( if (!NodeVarHasDtype(var_node)) return;
framework::ProgramDesc* program_desc) { auto* out_var = var_node->Var();
std::unordered_set<std::string> special_weights; if (out_var->GetDataType() == framework::proto::VarType::FP32) {
size_t block_size = program_desc->Size(); if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
std::unordered_set<std::string> block_0_weights;
for (auto var : program_desc->Block(0).AllVars()) {
if (var->Persistable()) block_0_weights.insert(var->Name());
}
for (size_t i = 1; i < block_size; ++i) {
// std::cout << program_desc->MutableBlock(i)->Proto()->DebugString() <<
// std::endl;;
auto all_ops = program_desc->Block(i).AllOps();
for (auto op : all_ops) {
for (auto name : op->InputArgumentNames()) {
if (block_0_weights.count(name)) special_weights.insert(name);
}
}
} }
VLOG(3) << " out_node name " << var_node->Name() << " data_type "
return special_weights; << out_var->GetDataType();
} }
// Just process special cases for weights conversion. // Just process special cases for weights conversion.
...@@ -143,21 +244,8 @@ bool WeightsShouldNotConvert(ir::Node* var_node) { ...@@ -143,21 +244,8 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
} }
} }
// If cur_op's next is condition_flow op, then cur op should be fp32. Note, we
// now only convert to mixed in block 0.
for (auto* op_node : op_nodes) {
for (auto var : op_node->outputs) {
for (auto next_op : var->outputs) {
if (next_op->Op()->HasAttr("sub_block")) {
return true;
}
}
}
}
return false; return false;
} }
inline bool IsFloatVarType(framework::proto::VarType::Type type) { inline bool IsFloatVarType(framework::proto::VarType::Type type) {
if (type == framework::proto::VarType::FP16 || if (type == framework::proto::VarType::FP16 ||
type == framework::proto::VarType::FP32 || type == framework::proto::VarType::FP32 ||
...@@ -165,6 +253,56 @@ inline bool IsFloatVarType(framework::proto::VarType::Type type) { ...@@ -165,6 +253,56 @@ inline bool IsFloatVarType(framework::proto::VarType::Type type) {
return true; return true;
return false; return false;
} }
// Handles one input var of `op_node` during mixed-precision conversion.
//
// The effective input dtype is the var desc's dtype, except in a non-main
// block where a dtype recorded in `vars_in_multi_block_map` for that name
// takes precedence.
//
//  - Persistable FP32 input of an op that supports the precision: its var
//    desc dtype is switched to `to_type`, unless WeightsShouldNotConvert
//    vetoes it (in which case the function returns without logging).
//  - Non-persistable floating-point input whose dtype differs from
//    `to_type`: a cast op is inserted via AddCastOp, whether or not the op
//    supports the precision.
//  - Anything else is left as is.
void ProcessInputNode(
    bool support_precision,
    framework::ir::Graph* graph,
    ir::Node* in_node,
    ir::Node* op_node,
    int* suffix,
    framework::BlockDesc* block_desc,
    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map,
    framework::proto::VarType::Type to_type,
    bool is_main_block,
    std::unordered_map<std::string, framework::proto::VarType::Type>*
        vars_in_multi_block_map) {
  if (!NodeVarHasDtype(in_node)) return;

  auto* in_var = in_node->Var();
  auto in_var_type = in_var->GetDataType();
  if (!is_main_block) {
    const auto recorded = vars_in_multi_block_map->find(in_var->Name());
    if (recorded != vars_in_multi_block_map->end()) {
      in_var_type = recorded->second;
    }
  }

  const bool is_weight = in_var->Persistable();
  if (support_precision && is_weight &&
      in_var_type == framework::proto::VarType::FP32) {
    if (WeightsShouldNotConvert(in_node)) return;
    in_var->SetDataType(to_type);
  } else if (!is_weight && IsFloatVarType(in_var_type) &&
             in_var_type != to_type) {
    AddCastOp(graph,
              in_node,
              op_node,
              in_var_type,
              to_type,
              suffix,
              block_desc,
              cast_map);
  }

  VLOG(3) << " in_node name " << in_var->Name() << " data_type "
          << in_var->GetDataType();
}
void ConvertAllFp64ToFp32(framework::ir::Graph* graph) { void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
auto op_nodes = framework::ir::TopologySortOperations(*graph); auto op_nodes = framework::ir::TopologySortOperations(*graph);
...@@ -239,6 +377,11 @@ void HandleSpecialOps(framework::OpDesc* op_desc) { ...@@ -239,6 +377,11 @@ void HandleSpecialOps(framework::OpDesc* op_desc) {
static_cast<int>(framework::proto::VarType::FP32)) static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype", op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16)); static_cast<int>(framework::proto::VarType::FP16));
} else if (op_desc->Type() == "fill_constant_batch_size_like") {
if (PADDLE_GET_CONST(int, op_desc->GetAttr("dtype")) ==
static_cast<int>(framework::proto::VarType::FP32))
op_desc->SetAttr("dtype",
static_cast<int>(framework::proto::VarType::FP16));
} }
} }
...@@ -260,26 +403,47 @@ void FixCastAttr(framework::ir::Graph* graph) { ...@@ -260,26 +403,47 @@ void FixCastAttr(framework::ir::Graph* graph) {
} }
} }
// If op's output var is condition flow op's input, then the op must be fp32 void FindVarsInMultiBlock(
// precision. framework::ProgramDesc* program_desc,
bool NextOpIncludesConditionFlowOp(framework::ir::Node* cur_op_node) { std::unordered_map<std::string, framework::proto::VarType::Type>*
auto cur_op_outs = cur_op_node->outputs; vars_in_multi_block_map) {
for (auto out_var : cur_op_outs) { std::set<std::string> vars_in_multi_block;
for (auto next_op_node : out_var->outputs) { std::set<std::string> main_block_var_names_set;
if (next_op_node->Op()->HasAttr("sub_block")) { for (auto op : program_desc->Block(0).AllOps()) {
return true; auto in_names = op->InputArgumentNames();
} main_block_var_names_set.insert(in_names.begin(), in_names.end());
}
for (size_t i = 1; i < program_desc->Size(); ++i) {
std::set<std::string> block_var_names_set;
for (auto op : program_desc->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set.insert(in_names.begin(), in_names.end());
} }
std::set_intersection(
main_block_var_names_set.begin(),
main_block_var_names_set.end(),
block_var_names_set.begin(),
block_var_names_set.end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
}
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map->emplace(name, framework::proto::VarType::FP32);
} }
return false;
} }
void ConvertTensorDtype(framework::ProgramDesc* program_desc, void ConvertTensorDtype(
framework::ir::Graph* graph, framework::ProgramDesc* program_desc,
const std::unordered_set<std::string>& blacklist, framework::ir::Graph* graph,
bool keep_io_types, const std::unordered_set<std::string>& blacklist,
phi::Backend backend, bool keep_io_types,
phi::DataType tensor_dtype) { phi::Backend backend,
phi::DataType tensor_dtype,
bool is_main_block,
std::unordered_map<std::string, framework::proto::VarType::Type>*
vars_in_multi_block_map) {
framework::proto::VarType::Type to_type; framework::proto::VarType::Type to_type;
if (tensor_dtype == phi::DataType::FLOAT16) { if (tensor_dtype == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16; to_type = framework::proto::VarType::FP16;
...@@ -287,25 +451,27 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc, ...@@ -287,25 +451,27 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
to_type = framework::proto::VarType::BF16; to_type = framework::proto::VarType::BF16;
} else { } else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument( PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only support " "mixed_precision currently not supported dtype %d, we now only "
"support "
"fp16 and bf16.", "fp16 and bf16.",
static_cast<int>(tensor_dtype))); static_cast<int>(tensor_dtype)));
} }
auto weight_name_in_multi_block = GetMultiBlockPersistableNames(program_desc); auto* block_desc =
framework::ir::TopologySortOperations(*graph)[0]->Op()->Block();
int num_low_precision = 0; int num_low_precision = 0;
int suffix = 0; int suffix = 0;
framework::BlockDesc* block_desc{nullptr};
std::vector<framework::ir::Node*> output_nodes; std::vector<framework::ir::Node*> output_nodes;
std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map; std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map;
auto op_nodes = framework::ir::TopologySortOperations(*graph); auto op_nodes = framework::ir::TopologySortOperations(*graph);
for (auto* op_node : op_nodes) { for (auto* op_node : op_nodes) {
if (!op_node->IsOp()) continue; if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type(); auto op_type = op_node->Op()->Type();
auto phi_op_type = phi::TransToPhiKernelName(op_type); VLOG(3) << "-------------------- op_type " << op_type << ", phi_type "
<< phi::TransToPhiKernelName(op_type);
// 1. set input dtype. // 1. set input dtype.
if (op_type == "feed") { if (op_type == "feed") {
block_desc = op_node->Op()->Block();
auto feed_var = op_node->outputs[0]->Var(); auto feed_var = op_node->outputs[0]->Var();
if (!keep_io_types && if (!keep_io_types &&
feed_var->GetDataType() == framework::proto::VarType::FP32) { feed_var->GetDataType() == framework::proto::VarType::FP32) {
...@@ -319,71 +485,73 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc, ...@@ -319,71 +485,73 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
continue; continue;
} }
else if (op_node->Op()->HasAttr("sub_block")) { // NOLINT
// sub_block op's output dtype should be same as input dtype, if have the
// same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) {
if (NodeVarHasDtype(in)) {
in_name_to_node[in->Name()] = in;
}
}
for (auto out : op_node->outputs) {
if (NodeVarHasDtype(out)) {
if (in_name_to_node.count(out->Name()))
out->Var()->SetDataType(
in_name_to_node[out->Name()]->Var()->GetDataType());
}
}
continue;
}
// 2. if op support fp16/bf16 and not in blacklist. // 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16. // - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16. // - add cast op if the input dtype is not fp16/bf16.
// - set output dtype. // - set output dtype.
else if (blacklist.count(phi_op_type) == 0 && // NOLINT else if (blacklist.count(op_type) == 0) { // NOLINT
!NextOpIncludesConditionFlowOp(op_node)) {
bool support_precision = bool support_precision =
OpSupportPrecision(phi_op_type, backend, tensor_dtype, blacklist); OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
VLOG(2) << "op_type " << op_type << ", phi_op_type " << phi_op_type VLOG(2) << "op_type " << op_type << ", phi_op_type "
<< " support low precision " << support_precision << ", " << phi::TransToPhiKernelName(op_type) << " support low precision "
<< support_precision << ", "
<< reinterpret_cast<void*>(op_node->Op()->Block()); << reinterpret_cast<void*>(op_node->Op()->Block());
for (auto in_node : op_node->inputs) {
if (weight_name_in_multi_block.count(in_node->Name()))
support_precision = false;
}
if (support_precision) { if (support_precision) {
HandleSpecialOps(op_node->Op()); HandleSpecialOps(op_node->Op());
++num_low_precision; ++num_low_precision;
auto inputs = op_node->inputs; auto inputs = op_node->inputs;
// Process inputs.
for (auto* in_node : inputs) { for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue; ProcessInputNode(true,
auto* in_var = in_node->Var(); graph,
if (in_var->Persistable() && in_node,
in_var->GetDataType() == framework::proto::VarType::FP32) { op_node,
if (WeightsShouldNotConvert(in_node)) continue; &suffix,
in_var->SetDataType(to_type); block_desc,
} else if (!in_var->Persistable() && &cast_map,
IsFloatVarType(in_var->GetDataType()) && to_type,
in_var->GetDataType() != to_type) { is_main_block,
AddCastOp(graph, vars_in_multi_block_map);
in_node,
op_node,
in_var->GetDataType(),
to_type,
&suffix,
block_desc,
&cast_map);
}
} }
// Process outputs.
for (auto* out_node : op_node->outputs) { for (auto* out_node : op_node->outputs) {
if (out_node->IsCtrlVar()) continue; ProcessOutputNode(out_node, to_type);
auto* out_var = out_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(out_node)) continue;
out_var->SetDataType(to_type);
}
} }
} else { } else {
auto inputs = op_node->inputs; auto inputs = op_node->inputs;
for (auto* in_node : inputs) { for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue; ProcessInputNode(false,
auto* in_var = in_node->Var(); graph,
if (!in_var->Persistable() && IsFloatVarType(in_var->GetDataType()) && in_node,
in_var->GetDataType() != framework::proto::VarType::FP32) { op_node,
AddCastOp(graph, &suffix,
in_node, block_desc,
op_node, &cast_map,
in_var->GetDataType(), framework::proto::VarType::FP32,
framework::proto::VarType::FP32, is_main_block,
&suffix, vars_in_multi_block_map);
block_desc,
&cast_map);
}
} }
} }
} }
...@@ -409,8 +577,8 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc, ...@@ -409,8 +577,8 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
} }
} }
// 4. if output_op's dtype is not compatible to output dtype, then just insert // 4. if output_op's dtype is not compatible to output dtype, then just
// cast. // insert cast.
for (auto* node : output_nodes) { for (auto* node : output_nodes) {
if (node->IsCtrlVar()) continue; if (node->IsCtrlVar()) continue;
auto var = node->Var(); auto var = node->Var();
...@@ -438,22 +606,31 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc, ...@@ -438,22 +606,31 @@ void ConvertTensorDtype(framework::ProgramDesc* program_desc,
} }
} }
if (is_main_block) {
for (auto node : graph->Nodes()) {
if (vars_in_multi_block_map->count(node->Name())) {
vars_in_multi_block_map->at(node->Name()) = node->Var()->GetDataType();
}
}
}
if (num_low_precision) if (num_low_precision)
LOG(INFO) << "--- detected " << num_low_precision << " low precision ops"; LOG(INFO) << "--- detected " << num_low_precision << " low precision ops";
} }
} // namespace } // namespace
bool OpSupportPrecision(const std::string& phi_op_type, bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend, phi::Backend backend,
phi::DataType precision, phi::DataType precision,
const std::unordered_set<std::string>& blacklist) { const std::unordered_set<std::string>& blacklist) {
auto phi_op_type = phi::TransToPhiKernelName(op_type);
bool support_precision = false; bool support_precision = false;
if (blacklist.count(phi_op_type) == 0) { if (blacklist.count(op_type) == 0) {
if (backend == phi::Backend::GPU) if (backend == phi::Backend::GPU)
support_precision = GpuKernelSupportPrecision(phi_op_type, precision); support_precision = GpuKernelSupportPrecision(op_type, precision);
else else
support_precision = support_precision =
IsKernelSupportPrecision(phi_op_type, backend, precision); PhiKernelSupportPrecision(phi_op_type, backend, precision);
} }
return support_precision; return support_precision;
} }
...@@ -521,102 +698,41 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -521,102 +698,41 @@ void ConvertToMixedPrecision(const std::string& model_file,
framework::Scope scope; framework::Scope scope;
auto program_desc = auto program_desc =
inference::Load(&executor, &scope, model_file, params_file); inference::Load(&executor, &scope, model_file, params_file);
auto graph = std::unique_ptr<framework::ir::Graph>( auto main_graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc)); new framework::ir::Graph(*program_desc));
ConvertAllFp64ToFp32(graph.get()); std::unordered_map<std::string, framework::proto::VarType::Type>
ConvertTensorDtype(program_desc.get(), vars_in_multi_block_map;
graph.get(), FindVarsInMultiBlock(program_desc.get(), &vars_in_multi_block_map);
black_list,
keep_io_types, for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) {
backend, auto graph = main_graph->GetSubGraph(i);
mixed_precision); VLOG(2) << " -------- handle subgraph " << i << ", has "
FixCastAttr(graph.get()); << graph->Nodes().size() << " nodes";
framework::ProgramDesc mixed_program_desc; program_desc->Block(i).LocalVarNames();
framework::ir::GraphToProgram(*graph, &mixed_program_desc);
ConvertAllFp64ToFp32(graph);
auto parameters = scope.LocalVarNames(); ConvertTensorDtype(program_desc.get(),
std::sort(parameters.begin(), parameters.end()); graph,
black_list,
auto serialize_params = keep_io_types,
[](framework::Scope* scope, backend,
const std::vector<std::string>& params) -> std::string { mixed_precision,
std::ostringstream os; i == 0,
phi::CPUContext ctx; &vars_in_multi_block_map);
for (const auto& param : params) { FixCastAttr(graph);
VLOG(3) << "Serialize param: " << param;
PADDLE_ENFORCE_NOT_NULL(
scope->FindVar(param),
platform::errors::NotFound(
"Block should already have a '%s' variable", param));
auto* tensor = scope->FindVar(param)->GetMutable<framework::LoDTensor>();
framework::SerializeToStream(os, *tensor, ctx);
}
return os.str();
};
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : graph->Nodes()) {
if (!(node->IsVar() && !node->IsCtrlVar())) continue;
if (node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
for (const auto& param_name : parameters) {
auto* var = scope.FindLocalVar(param_name);
if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto* t = var->GetMutable<framework::LoDTensor>();
framework::Tensor mixed_tensor;
mixed_tensor.Resize(t->dims());
auto* data = t->mutable_data<float>(platform::CPUPlace());
if (mixed_precision == phi::DataType::FLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::FLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<float16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
} else if (mixed_precision == phi::DataType::BFLOAT16 &&
!weights_should_be_fp32.count(param_name)) {
mixed_tensor.set_type(paddle::experimental::DataType::BFLOAT16);
auto* mixed_data =
mixed_tensor.mutable_data<bfloat16>(platform::CPUPlace());
for (int i = 0; i < t->numel(); i++) {
mixed_data[i] = static_cast<bfloat16>(data[i]);
}
t->clear();
paddle::framework::TensorCopySync(mixed_tensor, place, t);
}
}
} }
auto StrToBinary = [](const std::string& path, const std::string& str) { framework::ProgramDesc mixed_program_desc;
std::ofstream file(path.c_str(), std::ios::binary); framework::ir::GraphToProgram(*main_graph, &mixed_program_desc);
file.write(str.c_str(), str.size());
file.close(); SaveMixedModel(main_graph.get(),
}; &scope,
StrToBinary(mixed_model_file, &mixed_program_desc,
mixed_program_desc.Proto()->SerializeAsString()); mixed_model_file,
StrToBinary(mixed_params_file, serialize_params(&scope, parameters)); mixed_params_file,
mixed_precision);
} }
} // namespace analysis } // namespace analysis
......
...@@ -410,6 +410,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { ...@@ -410,6 +410,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
pass_builder_->DeletePass(ps); pass_builder_->DeletePass(ps);
} }
} }
for (auto &delete_pass : other.pass_builder()->GetAllDeletedPasses()) {
pass_builder_->DeletePass(delete_pass);
}
} }
void AnalysisConfig::EnableCUDNN() { void AnalysisConfig::EnableCUDNN() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册