// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"

#include <memory>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/backend.h"

namespace paddle {
namespace inference {
namespace analysis {

// Constructs the pass that rewrites a saved inference model so it runs in a
// reduced precision (fp16/bf16). Input/output file paths and conversion
// options are captured here; the actual work happens in Run().
//
// Validation is done eagerly so that a misconfigured precision or backend
// fails at construction time instead of halfway through the conversion.
ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
    const std::string& model_file,
    const std::string& params_file,
    const std::string& mixed_model_file,
    const std::string& mixed_params_file,
    phi::DataType mixed_precision,
    phi::Backend backend,
    bool keep_io_types,
    const std::unordered_set<std::string>& black_list)
    : model_file_(model_file),
      params_file_(params_file),
      mixed_model_file_(mixed_model_file),
      mixed_params_file_(mixed_params_file),
      mixed_precision_(mixed_precision),
      backend_(backend),
      keep_io_types_(keep_io_types),
      black_list_(black_list) {
  // Only fp16 and bf16 targets are supported.
  if (mixed_precision_ != phi::DataType::FLOAT16 &&
      mixed_precision_ != phi::DataType::BFLOAT16) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "mixed_precision currently not supported dtype %d, we now only "
        "support fp16 and bf16.",
        static_cast<int>(mixed_precision_)));
  }
  // The conversion is currently GPU-only (see "enable_gpu_mixed" in Run()).
  if (backend_ != phi::Backend::GPU) {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "mixed_precision currently not supported place %d, we now only "
        "support gpu.",
        static_cast<int>(backend_)));
  }
}
void ConvertToMixedPrecisionPass::LoadModel() {
  framework::Executor exe{platform::CPUPlace{}};
61

62
  auto program_desc = inference::Load(&exe, &scope_, model_file_, params_file_);
63
  main_graph_ = std::unique_ptr<framework::ir::Graph>(
64 65
      new framework::ir::Graph(*program_desc));
  main_graph_->SetNotOwned(framework::ir::kParamScopeAttr, &scope_);
66 67
}

void ConvertToMixedPrecisionPass::Run() {
69
  LoadModel();
70

71 72 73 74 75 76
  framework::ir::AutoMixedPrecisionPass pass;
  pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
  pass.Set("mixed_black_list",
           new std::unordered_set<std::string>{black_list_});
  pass.Set("enable_gpu_mixed", new bool{true});
  pass.Set("keep_io_types", new bool{keep_io_types_});
77

78
  pass.Apply(main_graph_.get());
W
Wilber 已提交
79

80
  SaveMixedModel();
81 82
}

void ConvertToMixedPrecisionPass::SaveMixedModel() {
  framework::ProgramDesc mixed_program_desc;
  framework::ir::GraphToProgram(*main_graph_, &mixed_program_desc);

  auto parameters = scope_.LocalVarNames();
  std::sort(parameters.begin(), parameters.end());

  auto SerializeParams = [&]() -> std::string {
    std::ostringstream os;
    phi::CPUContext ctx;
    for (const auto& param : parameters) {
      PADDLE_ENFORCE_NOT_NULL(
          scope_.FindVar(param),
          platform::errors::NotFound(
              "Block should already have a '%s' variable", param));
98
      auto* tensor = scope_.FindVar(param)->GetMutable<phi::DenseTensor>();
99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
      framework::SerializeToStream(os, *tensor, ctx);
    }
    return os.str();
  };

  auto StrToBinary = [](const std::string& path, const std::string& str) {
    std::ofstream file(path.c_str(), std::ios::binary);
    file.write(str.c_str(), str.size());
    file.close();
  };

  StrToBinary(mixed_model_file_,
              mixed_program_desc.Proto()->SerializeAsString());
  StrToBinary(mixed_params_file_, SerializeParams());
}
// Public wrapper: returns whether `op_type` can run at `precision` on
// `backend` and is not excluded by `black_list`. Delegates directly to the
// framework-level implementation so analysis code does not depend on the ir
// pass header.
bool OpSupportPrecision(const std::string& op_type,
                        phi::Backend backend,
                        phi::DataType precision,
                        const std::unordered_set<std::string>& black_list) {
  return framework::ir::OpSupportPrecision(
      op_type, backend, precision, black_list);
}

// Public wrapper around framework::ir::DoInsertCastOp: inserts a cast op in
// `graph` between `var_node` and its consumer/producer `op_node`, converting
// `from_type` to `to_type`. New ops/vars are registered in `block_desc`.
// NOTE(review): `suffix` presumably provides a counter for unique cast-var
// names and `visited` caches already-cast var nodes so casts are reused —
// semantics live in DoInsertCastOp; confirm there.
void InsertCastOp(
    framework::ir::Graph* graph,
    framework::ir::Node* var_node,
    framework::ir::Node* op_node,
    framework::proto::VarType::Type from_type,
    framework::proto::VarType::Type to_type,
    framework::BlockDesc* block_desc,
    int* suffix,
    std::unordered_map<framework::ir::Node*, framework::ir::Node*>* visited) {
  framework::ir::DoInsertCastOp(graph,
                                var_node,
                                op_node,
                                from_type,
                                to_type,
                                block_desc,
                                suffix,
                                visited);
}

void ConvertToMixedPrecision(
    const std::string& model_file,
    const std::string& params_file,
    const std::string& mixed_model_file,
    const std::string& mixed_params_file,
    phi::DataType mixed_precision,
    phi::Backend backend,
    bool keep_io_types,
    const std::unordered_set<std::string>& black_list) {
151 152 153 154 155 156 157 158 159
  ConvertToMixedPrecisionPass pass(model_file,
                                   params_file,
                                   mixed_model_file,
                                   mixed_params_file,
                                   mixed_precision,
                                   backend,
                                   keep_io_types,
                                   black_list);
  pass.Run();
160 161 162 163 164
}

}  // namespace analysis
}  // namespace inference
}  // namespace paddle