Unverified commit 9bda10cd authored by Yuanle Liu, committed by GitHub

[Inference] fix mixed precision (#47794)

Parent 78c8c7de
@@ -111,7 +111,7 @@ class ConvertToMixedPrecisionPass {
         black_list_(black_list),
         place_(paddle::CPUPlace()),
         executor_(place_) {
-    black_list_.insert("assign");
+    // black_list_.insert("assign");
     black_list_.insert("fill_constant");
     black_list_.insert("assign_value");
     black_list_.insert("eye");
@@ -416,8 +416,7 @@ void ConvertToMixedPrecisionPass::LoadAndPrepare() {
 void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
   std::unordered_set<std::string> all_var_names_set;
-  std::vector<std::unordered_set<std::string>> block_var_names_set(
-      program_desc_->Size());
+  std::vector<std::set<std::string>> block_var_names_set(program_desc_->Size());
   for (BlockID idx = 0; idx < program_desc_->Size(); ++idx) {
     for (auto* op : program_desc_->Block(idx).AllOps()) {
       const auto& in_names = op->InputArgumentNames();
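
Note: swapping std::unordered_set for std::set matters if the per-block name sets are later combined with std::set_intersection, which requires sorted input ranges; std::set iterates in sorted order while std::unordered_set does not. The snippet below only illustrates that requirement and is not the pass's own code.

#include <algorithm>
#include <iterator>
#include <set>
#include <string>
#include <vector>

// Illustrative sketch: variables named in both blocks. std::set keeps its
// elements sorted, which std::set_intersection needs to work correctly.
std::vector<std::string> VarsInBothBlocks(const std::set<std::string>& a,
                                          const std::set<std::string>& b) {
  std::vector<std::string> common;
  std::set_intersection(
      a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(common));
  return common;
}
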
@@ -449,7 +448,7 @@ void ConvertToMixedPrecisionPass::FindVarsInMultiBlock() {
     for (const auto& name : vars_in_multi_block) {
       vars_in_multi_block_with_pair_.emplace(
-          name, std::make_pair(VarType::FP32, idx));
+          name, std::make_pair(VarType::Type(), idx));
     }
   }
 }
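
Note: the map entry now starts from a value-initialized VarType::Type() instead of a hard-coded VarType::FP32, i.e. a zero-valued placeholder dtype that presumably gets overwritten once the variable's real dtype is known. The snippet below only demonstrates value-initialization of an enum with a made-up Dtype, not Paddle's proto::VarType::Type.

// Illustrative only: value-initializing an enum yields its zero value,
// so Dtype() acts as a neutral "not yet determined" placeholder.
enum class Dtype { kUnset = 0, kFP16 = 1, kFP32 = 2 };

static_assert(Dtype() == Dtype::kUnset,
              "value-initialization gives the zero enumerator");
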
@@ -612,7 +611,7 @@ void ConvertToMixedPrecisionPass::ConvertTensorDtype(BlockID block_idx) {
       if (!has_float_input) {
         support_precision = false;
-        VLOG(2) << " op doesn't has float input and output, just skip.";
+        VLOG(2) << " op doesn't has float input, just skip.";
       }
     }
     VLOG(2) << "op type: " << op_type