Unverified commit a00aebe1, authored by W Wilber, committed by GitHub

[convert_to_mixed_precision] fallback to fp32 when encounter circle (#47902)

Parent: d4d3d7ed
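Context for the patch below: the old pass special-cased variables shared across sub-blocks (FindVarsInMultiBlock, VarIsMultiPrecisionOpsOut) and patched ops whose input and output share a name (PatchForStrangeOp, e.g. fused_multi_transformer). This commit replaces those paths with one rule: collect every variable name that forms a circle, i.e. appears in both the inputs and the outputs of a single op, and keep any op that writes such a variable in fp32. A minimal standalone sketch of that detection idea follows; the types and names here (OpDescLike, FindVarNamesInCircles) are illustrative only, not Paddle's API.

    // Standalone sketch: a variable is "in a circle" when some op lists the
    // same name as both an input and an output (an in-place style update).
    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <set>
    #include <string>
    #include <vector>

    struct OpDescLike {
      std::vector<std::string> input_names;
      std::vector<std::string> output_names;
    };

    // Collect names that appear on both sides of any single op.
    std::set<std::string> FindVarNamesInCircles(
        const std::vector<OpDescLike>& ops) {
      std::set<std::string> in_circles;
      for (const auto& op : ops) {
        std::set<std::string> ins(op.input_names.begin(), op.input_names.end());
        std::set<std::string> outs(op.output_names.begin(),
                                   op.output_names.end());
        std::set_intersection(ins.begin(), ins.end(), outs.begin(), outs.end(),
                              std::inserter(in_circles, in_circles.end()));
      }
      return in_circles;
    }

    int main() {
      // "cache_kv" is both input and output, so it is flagged.
      std::vector<OpDescLike> ops = {
          {{"x", "cache_kv"}, {"y", "cache_kv"}},
          {{"y"}, {"z"}},
      };
      for (const auto& name : FindVarNamesInCircles(ops))
        std::cout << name << " is in a circle\n";  // prints: cache_kv
      return 0;
    }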
@@ -40,7 +40,6 @@
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/common/place.h"
-#include "paddle/phi/core/tensor_meta.h"
 
 namespace paddle {
 namespace inference {
@@ -111,12 +110,10 @@ class ConvertToMixedPrecisionPass {
         black_list_(black_list),
         place_(paddle::CPUPlace()),
         executor_(place_) {
-    // black_list_.insert("assign");
-    black_list_.insert("fill_constant");
-    black_list_.insert("assign_value");
-    black_list_.insert("eye");
-    black_list_.insert("fill_any_like");
-    black_list_.insert("fill_constant_batch_size_like");
+    VLOG(4) << "black_list has ";
+    for (auto& name : black_list_) {
+      VLOG(4) << " - " << name;
+    }
   }
 
   void Run();
@@ -145,18 +142,11 @@ class ConvertToMixedPrecisionPass {
   // Just process special cases for weights conversion.
   bool WeightsShouldNotConvert(framework::ir::Node* var_node);
-  // To support multi block, we need to consider a lot of special cases.
   // Return Node* which first appers in block.
-  framework::ir::Node* GetRealVarNode(BlockID block_idx,
-                                      framework::ir::Node* node);
-  void FindVarsInMultiBlock();
-  inline bool VarIsMultiPrecisionOpsOut(BlockID block_idx,
-                                        framework::ir::Node* op_node);
+  framework::ir::Node* GetRealVarNode(framework::ir::Node* node);
 
- private:
-  // A trick. Patch for strange op, which input name equal to output name, such
-  // as `fused_multi_transformer`
-  void PatchForStrangeOp();
+  // Fallback to fp32 dtype when encounter circle (Not a DAG graph).
+  void ProcessCircleCases();
 
  private:
   std::string model_file_;
@@ -171,35 +161,21 @@ class ConvertToMixedPrecisionPass {
   framework::Executor executor_;
   framework::Scope scope_;
 
+  std::unordered_map<std::string, framework::ir::Node*> name2node_;
   std::unordered_map<framework::ir::Node*, framework::ir::Node*> cast_map_;
-  std::unordered_map<std::string, std::pair<VarType::Type, BlockID>>
-      vars_in_multi_block_with_pair_;
-  std::unordered_map<std::string, std::vector<std::string>>
-      vars_in_multi_block_with_ops_;
   int suffix_{0};
 
+  std::set<std::string> var_names_in_circles_;
+
   std::unique_ptr<framework::ProgramDesc> program_desc_{nullptr};
   std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
   std::vector<framework::ir::Graph*> graphes_;
 };
 
 framework::ir::Node* ConvertToMixedPrecisionPass::GetRealVarNode(
-    BlockID block_idx, framework::ir::Node* var_node) {
+    framework::ir::Node* var_node) {
   CHECK_EQ(var_node->IsVar(), true);
-
-  if (vars_in_multi_block_with_pair_.count(var_node->Name())) {
-    auto origin_blockId =
-        vars_in_multi_block_with_pair_.at(var_node->Name()).second;
-    if (block_idx != origin_blockId) {
-      auto* graph = graphes_[origin_blockId];
-      for (auto* node : graph->Nodes()) {
-        if (node->Name() == var_node->Name()) {
-          return node;
-        }
-      }
-    }
-  }
-
+  if (name2node_.count(var_node->Name())) return name2node_[var_node->Name()];
   return var_node;
 }
@@ -212,32 +188,6 @@ inline bool ConvertToMixedPrecisionPass::VarNodeHasDtype(
          (type == VarType::VOCAB);
 }
 
-// op1(fp32) -> var1, op2(fp16) -> var1
-// if and only if op1 and op2 both support fp16, we convert op1 and op2's
-// precision.
-inline bool ConvertToMixedPrecisionPass::VarIsMultiPrecisionOpsOut(
-    BlockID block_idx, framework::ir::Node* op_node) {
-  CHECK_EQ(op_node->IsOp(), true);
-  for (auto* var_node : op_node->outputs) {
-    if (!var_node->IsVar()) continue;
-    auto* real_var_node = GetRealVarNode(block_idx, var_node);
-    if (!real_var_node->Var()->Persistable() &&
-        vars_in_multi_block_with_ops_.count(var_node->Name())) {
-      for (const auto& op_type :
-           vars_in_multi_block_with_ops_.at(var_node->Name())) {
-        if (!OpSupportPrecision(
-                op_type, backend_, mixed_precision_, black_list_)) {
-          VLOG(2) << var_node->Name()
-                  << " is multi precision op's out, so we skip convert to fp16";
-          return true;
-        }
-      }
-    }
-  }
-  return false;
-}
-
 void ConvertToMixedPrecisionPass::ProcessInputNode(
     bool support_precision,
     framework::ir::Node* in_node,
@@ -247,18 +197,13 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
     VarType::Type to_type,
     BlockID block_idx) {
   if (!in_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, in_node);
+  auto* real_node = GetRealVarNode(in_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* graph = graphes_[block_idx];
-  bool is_main_block = block_idx == 0;
   auto* in_var = real_node->Var();
   auto in_var_type = in_var->GetDataType();
   auto prev_type = in_var_type;
-  bool is_in_multi_block = vars_in_multi_block_with_pair_.count(in_var->Name());
-  if (!is_main_block && is_in_multi_block) {
-    in_var_type = vars_in_multi_block_with_pair_.at(in_var->Name()).first;
-  }
+
   if (support_precision) {
     if (in_var->Persistable() && in_var_type == VarType::FP32) {
       if (WeightsShouldNotConvert(in_node)) return;
@@ -299,7 +244,7 @@ void ConvertToMixedPrecisionPass::ProcessInputNode(
 void ConvertToMixedPrecisionPass::ProcessOutputNode(
     BlockID block_idx, framework::ir::Node* var_node, VarType::Type to_type) {
   if (!var_node->IsVar()) return;
-  auto* real_node = GetRealVarNode(block_idx, var_node);
+  auto* real_node = GetRealVarNode(var_node);
   if (!VarNodeHasDtype(real_node)) return;
   auto* out_var = real_node->Var();
   auto prev_type = out_var->GetDataType();
@@ -400,9 +345,17 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   inference::Load(&executor_, &scope_, model_file_, params_file_);
   main_graph_ = std::unique_ptr<framework::ir::Graph>(
       new framework::ir::Graph(*program_desc_));
+
   for (size_t i = 0; i < main_graph_->SubGraphsSize(); ++i) {
     auto* graph = main_graph_->GetSubGraph(i);
     graphes_.push_back(graph);
+
+    for (auto* node : graph->Nodes()) {
+      if (!node->IsVar()) continue;
+      if (!name2node_.count(node->Name())) {
+        name2node_[node->Name()] = node;
+      }
+    }
   }
 
   // Remove all control var
@@ -411,46 +364,68 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
   arg.SetMainGraphNotOwned(main_graph_.get());
   pass.Run(&arg);
 
-  FindVarsInMultiBlock();
+  ProcessCircleCases();
 }
 
-void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
-  std::unordered_set<std::string> all_var_names_set;
-  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
-  for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
+// Find var names which in circles.
+void ConvertToMixedPrecisionPass::ProcessCircleCases() {
+  std::vector<std::string> vars_in_circles;
+  for (size_t idx = 0; idx < program_desc_->Size(); ++idx) {
     for (auto* op : program_desc_->Block(idx).AllOps()) {
+      // TODO(inference): batch_norm has circle, but we need to fuse it in conv
+      // op.
+      if (op->Type() == "batch_norm") continue;
       const auto& in_names = op->InputArgumentNames();
-      block_var_names_set[idx].insert(in_names.begin(), in_names.end());
       const auto& out_names = op->OutputArgumentNames();
-      block_var_names_set[idx].insert(out_names.begin(), out_names.end());
-
-      if (op->HasAttr("sub_block") == false) {
-        for (const auto& name : out_names) {
-          if (all_var_names_set.count(name)) {
-            vars_in_multi_block_with_ops_[name].push_back(op->Type());
-          }
-        }
-      }
-      all_var_names_set.insert(block_var_names_set[idx].begin(),
-                               block_var_names_set[idx].end());
+      std::set<std::string> in_names_set(in_names.begin(), in_names.end());
+      std::set<std::string> out_names_set(out_names.begin(), out_names.end());
+      std::set_intersection(in_names_set.begin(),
+                            in_names_set.end(),
+                            out_names_set.begin(),
+                            out_names_set.end(),
+                            std::back_inserter(vars_in_circles));
     }
   }
 
-  CHECK_GT(program_desc_->Size(), 0U);
-  for (BlockID idx = 0; idx < program_desc_->Size() - 1; ++idx) {
-    for (BlockID jdx = idx + 1; jdx < program_desc_->Size(); ++jdx) {
-      std::vector<std::string> vars_in_multi_block;
-      std::set_intersection(block_var_names_set[idx].begin(),
-                            block_var_names_set[idx].end(),
-                            block_var_names_set[jdx].begin(),
-                            block_var_names_set[jdx].end(),
-                            std::back_inserter(vars_in_multi_block));
-
-      for (const auto& name : vars_in_multi_block) {
-        vars_in_multi_block_with_pair_.emplace(
-            name, std::make_pair(VarType::Type(), idx));
-      }
-    }
-  }
-}
+  for (auto& name : vars_in_circles) {
+    var_names_in_circles_.insert(name);
+  }
+  for (auto& name : var_names_in_circles_) {
+    LOG(INFO) << name
+              << " in circles, so we will skip process those vars and ops.";
+  }
+}
+
+inline void ProcessConstantOpAttr(framework::ir::Node* op_node,
+                                  VarType::Type from_type,
+                                  VarType::Type to_type) {
+  if (!op_node->IsOp()) return;
+  auto op_type = op_node->Op()->Type();
+  if (op_type == "feed" || op_type == "fetch") return;
+
+  if (op_type == "fill_constant") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "assign_value") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "eye") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "fill_any_like") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("dtype", static_cast<int>(to_type));
+  } else if (op_type == "cast") {
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("in_dtype", static_cast<int>(to_type));
+    if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
+        static_cast<int>(from_type))
+      op_node->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+  }
+}
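A note on the refactor above: the dtype-attribute fixups that ConvertAllFp64ToFp32 used to repeat for fill_constant, assign_value, eye, fill_any_like and cast are hoisted into the free helper ProcessConstantOpAttr, parameterized by (from_type, to_type). The fp64 path below calls ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32), and the mixed-precision path calls it with (VarType::FP32, to_type) once an op is selected for low precision, which is presumably why the constructor no longer force-inserts those constant-producing ops into black_list_.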
@@ -460,33 +435,7 @@ void ConvertToMixedPrecisionPass::ConvertAllFp64ToFp32(
   for (auto* op_node : op_nodes) {
     if (!op_node->IsOp()) continue;
     auto op_type = op_node->Op()->Type();
-    if (op_type == "feed" || op_type == "fetch") continue;
-
-    if (op_type == "fill_constant") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "assign_value") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "eye") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "fill_any_like") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("dtype", static_cast<int>(VarType::FP32));
-    } else if (op_type == "cast") {
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("in_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("in_dtype", static_cast<int>(VarType::FP32));
-      if (PADDLE_GET_CONST(int, op_node->Op()->GetAttr("out_dtype")) ==
-          static_cast<int>(VarType::FP64))
-        op_node->Op()->SetAttr("out_dtype", static_cast<int>(VarType::FP32));
-    }
-
+    ProcessConstantOpAttr(op_node, VarType::FP64, VarType::FP32);
     auto inputs = op_node->inputs;
     for (auto* in_node : inputs) {
       auto* in_var = in_node->Var();
@@ -509,9 +458,6 @@ void ConvertToMixedPrecisionPass::Run() {
     ConvertTensorDtype(i);
     FixCastAttr(graph);
 
-    // A trick
-    PatchForStrangeOp();
-
     CHECK_EQ(framework::ir::VarDescIsConsistency(*graph), true);
   }
@@ -556,28 +502,9 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
       continue;
     }
 
-    // We can not add cast operator before ops who have sub_block, as in
-    // sub_block we may get a var which may be transformer by cast op.
     else if (op_node->Op()->HasAttr("sub_block")) {  // NOLINT
-      // sub_block op's output dtype should be same as input dtype, if have the
-      // same name.
-      std::unordered_map<std::string, framework::ir::Node*> in_name_to_node;
-      for (auto* in : op_node->inputs) {
-        if (!in->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, in);
-        if (VarNodeHasDtype(real_node)) {
-          in_name_to_node[in->Name()] = in;
-        }
-      }
-
-      for (auto* out : op_node->outputs) {
-        if (!out->IsVar()) continue;
-        auto* real_node = GetRealVarNode(block_idx, out);
-        if (VarNodeHasDtype(real_node)) {
-          if (in_name_to_node.count(out->Name()))
-            real_node->Var()->SetDataType(
-                in_name_to_node[out->Name()]->Var()->GetDataType());
-        }
-      }
       continue;
     }
@@ -585,65 +512,75 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     //   - cast weight to fp16/bf16.
    //   - add cast op if the input dtype is not fp16/bf16.
    //   - set output dtype.
-    //
-    // If a var(op's out var) appears multiple times in graph, we should not
-    // convert to fp16.
-    else if (black_list_.count(op_type) == 0 &&  // NOLINT
-             !VarIsMultiPrecisionOpsOut(block_idx, op_node)) {
+    else if (black_list_.count(op_type) == 0) {  // NOLINT
       bool support_precision =
           OpSupportPrecision(op_type, backend_, mixed_precision_, black_list_);
 
-      // If the op has no input of float type, we will not choose the
+      // If op's output in circle, we should not convert to fp16.
+      for (auto* out_node : op_node->outputs) {
+        if (var_names_in_circles_.count(out_node->Name())) {
+          support_precision = false;
+          VLOG(2) << " op's output " << out_node->Name()
+                  << " is in circle, we can not support this case, just skip.";
+          break;
+        }
+      }
+
+      // If the op has no input or output of float type, we will not choose the
       // low precision kernel.
-      {
-        bool has_float_input{false};
+      if (support_precision) {
+        bool has_float_in_out{false};
         for (auto* in_node : op_node->inputs) {
           if (!in_node->IsVar()) continue;
-          auto* real_node = GetRealVarNode(block_idx, in_node);
+          if (in_node->Var()->GetType() != VarType::LOD_TENSOR) {
+            support_precision = false;
+            VLOG(2) << " op has tensor array input[" << in_node->Name()
+                    << "], just skip.";
+            break;
+          }
+          auto* real_node = GetRealVarNode(in_node);
+          if (real_node->Var()->GetDataType() == VarType::FP16 ||
+              real_node->Var()->GetDataType() == VarType::FP32 ||
+              real_node->Var()->GetDataType() == VarType::FP64 ||
+              real_node->Var()->GetDataType() == VarType::BF16) {
+            has_float_in_out = true;
+            break;
+          }
+        }
+        for (auto* out_node : op_node->outputs) {
+          if (!out_node->IsVar()) continue;
+          auto* real_node = GetRealVarNode(out_node);
           if (real_node->Var()->GetDataType() == VarType::FP16 ||
               real_node->Var()->GetDataType() == VarType::FP32 ||
              real_node->Var()->GetDataType() == VarType::FP64 ||
              real_node->Var()->GetDataType() == VarType::BF16) {
-            has_float_input = true;
+            has_float_in_out = true;
            break;
          }
        }
-        if (!has_float_input) {
+        if (!has_float_in_out) {
          support_precision = false;
-          VLOG(2) << " op doesn't has float input, just skip.";
+          VLOG(2) << " op doesn't has float input and output, just skip.";
        }
      }
+
      VLOG(2) << "op type: " << op_type
              << " support low precision: " << support_precision;
 
      if (support_precision) {
+        ProcessConstantOpAttr(op_node, VarType::FP32, to_type);
        VLOG(2) << " process input nodes:";
        ++num_low_precision;
        auto inputs = op_node->inputs;
-
-        // Just for paddle's terriable case: op's input and output has the same
-        // name.
-        std::unordered_map<std::string, std::string> names_map;
-        for (auto* out_node : op_node->outputs) {
-          for (auto* in_node : op_node->inputs) {
-            if (out_node->Name() == in_node->Name()) {
-              names_map[out_node->Name()] = in_node->Name();
-            }
-          }
-        }
-
-        // Process inputs.
        for (auto* in_node : inputs) {
          ProcessInputNode(
              true, in_node, op_node, &suffix_, block_desc, to_type, block_idx);
-          if (names_map.count(in_node->Name()) && cast_map_.count(in_node)) {
-            names_map[in_node->Name()] = cast_map_[in_node]->Name();
-          }
        }
 
        VLOG(2) << " process output nodes:";
-        // Process outputs.
-        for (auto* out_node : op_node->outputs) {
+        auto outputs = op_node->outputs;
+        for (auto* out_node : outputs) {
          ProcessOutputNode(block_idx, out_node, to_type);
        }
      } else {
@@ -663,8 +600,10 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     // 3. check op not support fp16/bf16 or in blacklist.
     //      - add cast op if the input dtype is not fp32.
     else {  // NOLINT
-      VLOG(3) << "not to run fp16 op_type: " << op_type;
-      for (auto* in_node : op_node->inputs) {
+      VLOG(3) << "not to run fp16 op_type: " << op_type << ", node input size "
+              << op_node->inputs.size();
+      auto in_nodes = op_node->inputs;
+      for (auto* in_node : in_nodes) {
         auto* in_var = in_node->Var();
         if (in_var->GetDataType() == to_type) {
           AddCastOp(graph,
@@ -716,21 +655,6 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
     }
   }
 
-  for (auto* node : graph->Nodes()) {
-    if (!node->IsVar()) continue;
-    auto* real_node = GetRealVarNode(block_idx, node);
-    if (!VarNodeHasDtype(real_node)) continue;
-
-    if (vars_in_multi_block_with_pair_.count(real_node->Name()) &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).second ==
-            block_idx &&
-        vars_in_multi_block_with_pair_.at(real_node->Name()).first ==
-            VarType::Type()) {
-      vars_in_multi_block_with_pair_.at(real_node->Name()).first =
-          real_node->Var()->GetDataType();
-    }
-  }
-
   if (num_low_precision)
     LOG(INFO) << "--- detected " << num_low_precision
               << " low precision ops in " << block_idx << " subgraph";
@@ -738,6 +662,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
 
 // We modify op's input output precision, and we need to fix cast op in_dtype
 // and out_dtype attribute.
+// TODO(inference): we need a cast elimination pass.
 void ConvertToMixedPrecisionPass::FixCastAttr(framework::ir::Graph* graph) {
   auto op_nodes = framework::ir::TopologySortOperations(*graph);
   for (auto* op_node : op_nodes) {
@@ -766,7 +691,8 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
     if (VarNodeHasDtype(node)) {
       if (node->Var()->Persistable() &&
           node->Var()->GetDataType() == VarType::FP32) {
-        VLOG(2) << "weights keep to fp32: " << node->Name();
+        VLOG(2) << "weights keep to fp32: " << node->Name() << ", ptr "
+                << reinterpret_cast<void*>(node->Var());
         weights_should_be_fp32.insert(node->Name());
       }
     }
@@ -808,7 +734,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
   std::ostringstream os;
   phi::CPUContext ctx;
   for (const auto& param : parameters) {
-    VLOG(3) << "Serialize param: " << param;
     PADDLE_ENFORCE_NOT_NULL(
         scope_.FindVar(param),
         platform::errors::NotFound(
@@ -829,21 +754,6 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
               mixed_program_desc.Proto()->SerializeAsString());
   StrToBinary(mixed_params_file_, SerializeParams());
 }
-
-void ConvertToMixedPrecisionPass::PatchForStrangeOp() {
-  for (auto* graph : graphes_) {
-    for (auto op_node : framework::ir::TopologySortOperations(*graph)) {
-      if (op_node->Name() == "fused_multi_transformer") {
-        auto cache_kv_inputs = op_node->Op()->Input("CacheKV");
-        auto cache_kv_outputs = op_node->Op()->Output("CacheKVOut");
-        CHECK_EQ(cache_kv_inputs.size(), cache_kv_outputs.size());
-        for (size_t i = 0; i < cache_kv_inputs.size(); ++i) {
-          op_node->Op()->RenameOutput(cache_kv_outputs[i], cache_kv_inputs[i]);
-        }
-      }
-    }
-  }
-}
 
 }  // namespace
 
 void AddCastOp(
@@ -893,6 +803,7 @@ void AddCastOp(
   }
   next_op->Op()->Rename(node->Name(), map->at(node)->Name());
   IR_NODE_LINK_TO(node, map->at(node)->inputs[0]);
+  IR_NODE_UNLINK(node, next_op);
   IR_NODE_LINK_TO(map->at(node), next_op);
 }
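A brief note on this last hunk: AddCastOp reroutes next_op to read the casted copy of the variable, and the new IR_NODE_UNLINK(node, next_op) removes the old direct edge from the original variable node to next_op, presumably so the graph is left with only the var, cast, casted-var, op chain rather than keeping a stale extra link.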
......