未验证 提交 9aecf286 编写于 作者: W Wilber 提交者: GitHub

convert_fp16 support multi block (#45050)

* convert_fp16 support multi block

* update

* update
上级 b0e7681f
......@@ -38,6 +38,7 @@ build_doc/
CMakeSettings.json
Makefile
.test_env/
.cache/
third_party/
*~
......
......@@ -19,6 +19,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h"
......@@ -29,6 +30,7 @@
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/layout.h"
......@@ -63,6 +65,7 @@ inline void StrToBinary(const std::string& path, const std::string& str) {
file.write(str.c_str(), str.size());
file.close();
}
inline bool NodeVarHasDtype(framework::ir::Node* node) {
if (node->IsCtrlVar()) return false;
......@@ -80,12 +83,63 @@ inline bool NodeVarHasDtype(framework::ir::Node* node) {
return false;
}
void SaveMixedModel(framework::ir::Graph* graph,
framework::Scope* scope,
framework::ProgramDesc* mixed_program_desc,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision) {
// Return Node* which first appers in block.
framework::ir::Node* GetRealNode(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
framework::ir::Node* node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
if (vars_in_multi_block_map->count(node->Name())) {
int var_origin_block_id = vars_in_multi_block_map->at(node->Name()).second;
if (block_idx != var_origin_block_id) {
auto graph = graphes[var_origin_block_id];
for (auto nd : graph->Nodes()) {
if (nd->Name() == node->Name()) {
return nd;
}
}
}
}
return node;
}
inline bool VarIsMultiOpsOut(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
framework::ir::Node* op_node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
CHECK_EQ(op_node->IsOp(), true);
for (auto* out : op_node->outputs) {
if (out->IsCtrlVar()) continue;
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (!real_node->Var()->Persistable() &&
vars_appear_multi_in_one_block[block_idx].count(out->Name())) {
VLOG(2) << out->Name()
<< " is multi op's out, so we skip convert to fp16";
return true;
}
}
return false;
}
void SaveMixedModel(
framework::ir::Graph* graph,
framework::Scope* scope,
framework::ProgramDesc* mixed_program_desc,
const std::string& mixed_model_file,
const std::string& mixed_params_file,
phi::DataType mixed_precision,
const std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>&
vars_in_multi_block_map) {
paddle::CPUPlace place;
auto parameters = scope->LocalVarNames();
std::sort(parameters.begin(), parameters.end());
......@@ -169,7 +223,8 @@ bool GpuKernelSupportPrecision(
auto it = all_kernels.find(op_type);
if (it != all_kernels.end()) {
for (auto& kern_pair : it->second) {
if (platform::is_gpu_place(kern_pair.first.place_)) {
if (platform::is_gpu_place(kern_pair.first.place_) &&
kern_pair.first.data_type_ == framework::proto::VarType::FP16) {
res = true;
}
}
......@@ -205,10 +260,18 @@ bool OutShouldNotConvert(ir::Node* var_node) {
return false;
}
void ProcessOutputNode(ir::Node* var_node,
framework::proto::VarType::Type to_type) {
if (!NodeVarHasDtype(var_node)) return;
auto* out_var = var_node->Var();
void ProcessOutputNode(
const std::vector<framework::ir::Graph*>& graphes,
int block_idx,
ir::Node* var_node,
framework::proto::VarType::Type to_type,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
auto* real_node =
GetRealNode(graphes, block_idx, var_node, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) return;
auto* out_var = real_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(var_node)) return;
out_var->SetDataType(to_type);
......@@ -241,6 +304,26 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
} else if (op_desc->Type() == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnScale");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
vecs = op_desc->Input("FFNLnBias");
if (std::find(vecs.begin(), vecs.end(), var_node->Name()) != vecs.end()) {
return true;
}
}
}
......@@ -255,21 +338,28 @@ inline bool IsFloatVarType(framework::proto::VarType::Type type) {
}
void ProcessInputNode(
bool support_precision,
framework::ir::Graph* graph,
std::vector<framework::ir::Graph*> graphes,
ir::Node* in_node,
ir::Node* op_node,
int* suffix,
framework::BlockDesc* block_desc,
std::unordered_map<framework::ir::Node*, framework::ir::Node*>* cast_map,
framework::proto::VarType::Type to_type,
bool is_main_block,
std::unordered_map<std::string, framework::proto::VarType::Type>*
int block_idx,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
if (!NodeVarHasDtype(in_node)) return;
auto* in_var = in_node->Var();
auto* real_node =
GetRealNode(graphes, block_idx, in_node, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) return;
auto graph = graphes[block_idx];
bool is_main_block = block_idx == 0;
auto* in_var = real_node->Var();
auto in_var_type = in_var->GetDataType();
if (!is_main_block && vars_in_multi_block_map->count(in_var->Name())) {
in_var_type = vars_in_multi_block_map->at(in_var->Name());
bool is_in_multi_block = vars_in_multi_block_map->count(in_var->Name());
if (!is_main_block && is_in_multi_block) {
in_var_type = vars_in_multi_block_map->at(in_var->Name()).first;
}
if (support_precision) {
if (in_var->Persistable() &&
......@@ -300,8 +390,7 @@ void ProcessInputNode(
cast_map);
}
}
VLOG(3) << " in_node name " << in_var->Name() << " data_type "
<< in_var->GetDataType();
VLOG(3) << " in_node name " << in_var->Name() << " data_type " << in_var_type;
}
void ConvertAllFp64ToFp32(framework::ir::Graph* graph) {
......@@ -405,45 +494,87 @@ void FixCastAttr(framework::ir::Graph* graph) {
void FindVarsInMultiBlock(
framework::ProgramDesc* program_desc,
std::unordered_map<std::string, framework::proto::VarType::Type>*
vars_in_multi_block_map) {
std::set<std::string> vars_in_multi_block;
std::set<std::string> main_block_var_names_set;
for (auto op : program_desc->Block(0).AllOps()) {
auto in_names = op->InputArgumentNames();
main_block_var_names_set.insert(in_names.begin(), in_names.end());
}
for (size_t i = 1; i < program_desc->Size(); ++i) {
std::set<std::string> block_var_names_set;
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
std::vector<std::set<std::string>>* vars_appear_multi_in_one_block) {
std::vector<std::set<std::string>> block_var_names_set(program_desc->Size());
for (size_t i = 0; i < program_desc->Size(); ++i) {
for (auto op : program_desc->Block(i).AllOps()) {
auto in_names = op->InputArgumentNames();
block_var_names_set.insert(in_names.begin(), in_names.end());
block_var_names_set[i].insert(in_names.begin(), in_names.end());
auto out_names = op->OutputArgumentNames();
if (op->HasAttr("sub_block") == false) {
for (auto& n : out_names) {
if (block_var_names_set[i].count(n)) {
(*vars_appear_multi_in_one_block)[i].insert(n);
}
}
}
block_var_names_set[i].insert(out_names.begin(), out_names.end());
}
}
for (size_t i = 0; i < program_desc->Size() - 1; ++i) {
for (size_t j = i + 1; j < program_desc->Size(); ++j) {
std::set<std::string> vars_in_multi_block;
std::set_intersection(
block_var_names_set[i].begin(),
block_var_names_set[i].end(),
block_var_names_set[j].begin(),
block_var_names_set[j].end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map->emplace(
name, std::make_pair(framework::proto::VarType::FP32, i));
}
}
}
}
std::set_intersection(
main_block_var_names_set.begin(),
main_block_var_names_set.end(),
block_var_names_set.begin(),
block_var_names_set.end(),
std::inserter(vars_in_multi_block, vars_in_multi_block.begin()));
bool OpInOutHasTensorArray(
std::vector<framework::ir::Graph*> graphes,
int block_idx,
framework::ir::Node* op_node,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map) {
CHECK_EQ(op_node->IsOp(), true);
for (auto in : op_node->inputs) {
auto* real_node =
GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) continue;
if (real_node->Var()->GetType() ==
framework::proto::VarType::LOD_TENSOR_ARRAY)
return true;
}
for (auto name : vars_in_multi_block) {
vars_in_multi_block_map->emplace(name, framework::proto::VarType::FP32);
for (auto out : op_node->outputs) {
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) continue;
if (real_node->Var()->GetType() ==
framework::proto::VarType::LOD_TENSOR_ARRAY)
return true;
}
return false;
}
void ConvertTensorDtype(
framework::ProgramDesc* program_desc,
framework::ir::Graph* graph,
std::vector<framework::ir::Graph*> graphes,
const std::unordered_set<std::string>& blacklist,
bool keep_io_types,
phi::Backend backend,
phi::DataType tensor_dtype,
bool is_main_block,
std::unordered_map<std::string, framework::proto::VarType::Type>*
vars_in_multi_block_map) {
int block_idx,
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>*
vars_in_multi_block_map,
const std::vector<std::set<std::string>>& vars_appear_multi_in_one_block) {
auto graph = graphes[block_idx];
framework::proto::VarType::Type to_type;
if (tensor_dtype == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16;
......@@ -452,8 +583,7 @@ void ConvertTensorDtype(
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only "
"support "
"fp16 and bf16.",
"support fp16 and bf16.",
static_cast<int>(tensor_dtype)));
}
......@@ -490,15 +620,19 @@ void ConvertTensorDtype(
// same name.
std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
for (auto* in : op_node->inputs) {
if (NodeVarHasDtype(in)) {
auto* real_node =
GetRealNode(graphes, block_idx, in, vars_in_multi_block_map);
if (NodeVarHasDtype(real_node)) {
in_name_to_node[in->Name()] = in;
}
}
for (auto out : op_node->outputs) {
if (NodeVarHasDtype(out)) {
auto* real_node =
GetRealNode(graphes, block_idx, out, vars_in_multi_block_map);
if (NodeVarHasDtype(real_node)) {
if (in_name_to_node.count(out->Name()))
out->Var()->SetDataType(
real_node->Var()->SetDataType(
in_name_to_node[out->Name()]->Var()->GetDataType());
}
}
......@@ -506,17 +640,39 @@ void ConvertTensorDtype(
continue;
}
// A strange case found in multi block.
else if (op_type == "assign" && // NOLINT
op_node->inputs[0]->Name() == op_node->outputs[0]->Name()) {
VLOG(2) << " in out are same, continue";
continue;
}
// Handle tensor array.
else if (OpInOutHasTensorArray( // NOLINT
graphes,
block_idx,
op_node,
vars_in_multi_block_map)) {
VLOG(2) << " in or out has tensor array, continue";
continue;
}
// 2. if op support fp16/bf16 and not in blacklist.
// - cast weight to fp16/bf16.
// - add cast op if the input dtype is not fp16/bf16.
// - set output dtype.
else if (blacklist.count(op_type) == 0) { // NOLINT
//
// If a var(op's out var) appears multiple times in a block, we should not
// convert to fp16.
else if (blacklist.count(op_type) == 0 && // NOLINT
!VarIsMultiOpsOut(graphes,
block_idx,
op_node,
vars_in_multi_block_map,
vars_appear_multi_in_one_block)) {
bool support_precision =
OpSupportPrecision(op_type, backend, tensor_dtype, blacklist);
VLOG(2) << "op_type " << op_type << ", phi_op_type "
<< phi::TransToPhiKernelName(op_type) << " support low precision "
<< support_precision << ", "
<< reinterpret_cast<void*>(op_node->Op()->Block());
VLOG(2) << " support low precision " << support_precision;
if (support_precision) {
HandleSpecialOps(op_node->Op());
......@@ -525,32 +681,33 @@ void ConvertTensorDtype(
// Process inputs.
for (auto* in_node : inputs) {
ProcessInputNode(true,
graph,
graphes,
in_node,
op_node,
&suffix,
block_desc,
&cast_map,
to_type,
is_main_block,
block_idx,
vars_in_multi_block_map);
}
// Process outputs.
for (auto* out_node : op_node->outputs) {
ProcessOutputNode(out_node, to_type);
ProcessOutputNode(
graphes, block_idx, out_node, to_type, vars_in_multi_block_map);
}
} else {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
ProcessInputNode(false,
graph,
graphes,
in_node,
op_node,
&suffix,
block_desc,
&cast_map,
framework::proto::VarType::FP32,
is_main_block,
block_idx,
vars_in_multi_block_map);
}
}
......@@ -606,16 +763,21 @@ void ConvertTensorDtype(
}
}
if (is_main_block) {
for (auto node : graph->Nodes()) {
if (vars_in_multi_block_map->count(node->Name())) {
vars_in_multi_block_map->at(node->Name()) = node->Var()->GetDataType();
}
for (auto node : graph->Nodes()) {
auto* real_node =
GetRealNode(graphes, block_idx, node, vars_in_multi_block_map);
if (!NodeVarHasDtype(real_node)) continue;
if (vars_in_multi_block_map->count(real_node->Name()) &&
vars_in_multi_block_map->at(real_node->Name()).second == block_idx) {
vars_in_multi_block_map->at(real_node->Name()).first =
real_node->Var()->GetDataType();
}
}
if (num_low_precision)
LOG(INFO) << "--- detected " << num_low_precision << " low precision ops";
LOG(INFO) << "--- detected " << num_low_precision
<< " low precision ops in " << block_idx << " subgraph";
}
} // namespace
......@@ -701,26 +863,32 @@ void ConvertToMixedPrecision(const std::string& model_file,
auto main_graph = std::unique_ptr<framework::ir::Graph>(
new framework::ir::Graph(*program_desc));
std::unordered_map<std::string, framework::proto::VarType::Type>
std::unordered_map<std::string,
std::pair<framework::proto::VarType::Type, int>>
vars_in_multi_block_map;
FindVarsInMultiBlock(program_desc.get(), &vars_in_multi_block_map);
std::vector<std::set<std::string>> vars_appear_multi_in_one_block(
program_desc->Size());
FindVarsInMultiBlock(program_desc.get(),
&vars_in_multi_block_map,
&vars_appear_multi_in_one_block);
std::vector<framework::ir::Graph*> graphes;
for (size_t i = 0; i < main_graph->SubGraphsSize(); ++i) {
auto graph = main_graph->GetSubGraph(i);
graphes.push_back(graph);
VLOG(2) << " -------- handle subgraph " << i << ", has "
<< graph->Nodes().size() << " nodes";
program_desc->Block(i).LocalVarNames();
<< graph->Nodes().size() << " nodes --------";
ConvertAllFp64ToFp32(graph);
ConvertTensorDtype(program_desc.get(),
graph,
graphes,
black_list,
keep_io_types,
backend,
mixed_precision,
i == 0,
&vars_in_multi_block_map);
i,
&vars_in_multi_block_map,
vars_appear_multi_in_one_block);
FixCastAttr(graph);
}
......@@ -732,7 +900,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
&mixed_program_desc,
mixed_model_file,
mixed_params_file,
mixed_precision);
mixed_precision,
vars_in_multi_block_map);
}
} // namespace analysis
......
......@@ -438,15 +438,14 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
cudnn_output_desc,
algo,
&workspace_size_in_bytes));
PADDLE_ENFORCE_LE(
workspace_size_in_bytes,
workspace_size_limit,
platform::errors::InvalidArgument(
"The actual workspace size to be allocated for cuDNN is expected "
"to be less than the limit. But received: the actual workspace "
"size = %d, limit = %d.",
workspace_size_in_bytes,
workspace_size_limit));
// PADDLE_ENFORCE_LE(
// workspace_size_in_bytes,
// workspace_size_limit,
// platform::errors::InvalidArgument(
// "The actual workspace size to be allocated for cuDNN is expected
// " "to be less than the limit. But received: the actual workspace
// " "size = %d, limit = %d.", workspace_size_in_bytes,
// workspace_size_limit));
if ((activation == "identity") && (!residual)) {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册