// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace ir {

// Per-variable record keyed by variable name: the bool is the "is unsigned"
// flag (persisted as the "..._unsigned" attribute by SaveInfoInTheTmpOp) and
// the tensor holds the associated values (the StringPairMap overload of
// GetInfoFromTheTmpOp fills it from a saved std::vector<float>).
using StringPairMap = std::unordered_map<std::string,
                                         std::pair<bool, phi::DenseTensor>>;

static void SaveInfoInTheTmpOp(
29 30 31
    ir::Graph* graph,
    const std::string& flag,
    const std::string& key_suffix,
B
baoachun 已提交
32 33 34 35
    const std::unordered_map<std::string, std::vector<float>>& info_map) {
  VLOG(3) << "save variables in the first op's attr";

  const std::string suffix = "_" + key_suffix + "_" + flag;
36 37 38 39 40 41 42
  OpDesc op_desc;
  op_desc.SetType("save");
  auto* op_node = graph->CreateOpNode(&op_desc);

  op_node->Op()->SetAttr(flag, true);
  for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) {
    op_node->Op()->SetAttr(iter->first + suffix, iter->second);
B
baoachun 已提交
43 44 45
  }
}

// Stashes `info_map` on a newly created "save" op node in `graph` so that the
// StringPairMap overload of GetInfoFromTheTmpOp can recover it later.
//
// The new node carries a bool attribute named `flag` (set to true) and, for
// every entry <key, (is_unsigned, tensor)>, two attributes:
//   "<key>_<key_suffix>_<flag>_unsigned" -> is_unsigned           (bool)
//   "<key>_<key_suffix>_<flag>"          -> tensor's float values (std::vector<float>)
//
// NOTE(review): the tensor is read with data<float>(), whereas
// GetInfoFromTheTmpOp fills the tensors it produces with double; re-saving
// such a tensor depends on data<float>() accepting it — confirm the dtype
// callers pass in.
//
// Declared `inline` rather than `static` because it is defined in a header.
inline void SaveInfoInTheTmpOp(ir::Graph* graph,
                               const std::string& flag,
                               const std::string& key_suffix,
                               const StringPairMap& info_map) {
  VLOG(3) << "save variables in the first op's attr";

  const std::string suffix = "_" + key_suffix + "_" + flag;

  OpDesc op_desc;
  op_desc.SetType("save");
  auto* op_node = graph->CreateOpNode(&op_desc);

  op_node->Op()->SetAttr(flag, true);
  for (const auto& entry : info_map) {
    const bool is_unsigned = entry.second.first;
    const phi::DenseTensor& tensor = entry.second.second;

    // Flatten the tensor into a plain vector so it can be stored as an attr.
    const float* data = tensor.data<float>();
    std::vector<float> data_v(data, data + tensor.numel());

    op_node->Op()->SetAttr(entry.first + suffix + "_unsigned", is_unsigned);
    op_node->Op()->SetAttr(entry.first + suffix, data_v);
  }
}

static void GetInfoFromTheTmpOp(
69 70 71
    ir::Graph* graph,
    const std::string& flag,
    const std::string& key_suffix,
B
baoachun 已提交
72 73 74 75 76 77
    std::unordered_map<std::string, std::vector<float>>* info_map) {
  VLOG(3) << "get variables from the first op's attr";

  const std::string suffix = "_" + key_suffix + "_" + flag;
  for (auto* op_node :
       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
78
    if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue;
B
baoachun 已提交
79 80 81 82 83 84 85 86 87 88

    auto* op_desc = op_node->Op();
    if (op_desc->GetAttrIfExists<bool>(flag)) {
      op_desc->RemoveAttr(flag);
      std::vector<std::string> attr_names = op_desc->AttrNames();
      for (auto fake_name : attr_names) {
        size_t pos = fake_name.find(suffix);
        if (pos != std::string::npos) {
          std::string name = fake_name.substr(0, pos);
          auto scales_vector =
R
Ruibiao Chen 已提交
89
              PADDLE_GET_CONST(std::vector<float>, op_desc->GetAttr(fake_name));
B
baoachun 已提交
90 91 92 93
          info_map->insert(std::make_pair(name, scales_vector));
          op_desc->RemoveAttr(fake_name);
        }
      }
94
      graph->RemoveNode(op_node);
B
baoachun 已提交
95 96 97 98 99
      break;
    }
  }
}

// Recovers the entries written by the StringPairMap overload of
// SaveInfoInTheTmpOp and removes the temporary op that carried them.
//
// Walks the graph (in TopologyVarientSort order) to the first "save" op node
// whose bool attribute `flag` is true. Each saved variable is described by two
// attributes on that op:
//   "<var>_<key_suffix>_<flag>_unsigned" (bool) and
//   "<var>_<key_suffix>_<flag>"          (std::vector<float>).
// For each such pair, the values are copied into a new CPU tensor of double
// and (is_unsigned, tensor) is inserted into `info_map` under <var>; keys
// already present are not overwritten. The consumed attributes and then the
// op node itself are removed. If no such op exists, nothing is changed.
//
// Declared `inline` rather than `static` because it is defined in a header.
inline void GetInfoFromTheTmpOp(ir::Graph* graph,
                                const std::string& flag,
                                const std::string& key_suffix,
                                StringPairMap* info_map) {
  VLOG(3) << "get variables from the first op's attr";
  const std::string unsigned_flag = "_unsigned";
  const std::string suffix = "_" + key_suffix + "_" + flag;
  const std::string suffix_is_unsigned = suffix + unsigned_flag;

  auto ends_with = [](const std::string& str, const std::string& tail) {
    return str.size() >= tail.size() &&
           str.compare(str.size() - tail.size(), tail.size(), tail) == 0;
  };

  for (auto* op_node :
       ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
    if (!op_node->IsOp() || op_node->Op()->Type() != "save") continue;

    auto* op_desc = op_node->Op();
    if (!op_desc->GetAttrIfExists<bool>(flag)) continue;

    op_desc->RemoveAttr(flag);
    // Iterate over a snapshot of the names: attributes are removed below.
    const std::vector<std::string> attr_names = op_desc->AttrNames();
    for (const auto& unsigned_var_name : attr_names) {
      // Drive the loop from the "_unsigned" attribute of each pair.
      if (!ends_with(unsigned_var_name, suffix_is_unsigned)) continue;

      const std::string var_name = unsigned_var_name.substr(
          0, unsigned_var_name.size() - suffix_is_unsigned.size());
      // Rebuild the companion name from var_name. Erasing the first
      // "_unsigned" found in the attribute name would hit the variable's own
      // name when it contains "_unsigned" and look up a non-existent attr.
      const std::string vector_name = var_name + suffix;

      const bool is_unsigned =
          PADDLE_GET_CONST(bool, op_desc->GetAttr(unsigned_var_name));
      // Copy (not reference): GetAttr returns the attribute by value.
      auto scales_vector = PADDLE_GET_CONST(std::vector<float>,
                                            op_desc->GetAttr(vector_name));

      phi::DenseTensor tensor;
      const int size = static_cast<int>(scales_vector.size());
      auto data = tensor.mutable_data<double>({size}, phi::CPUPlace());
      std::copy(scales_vector.begin(), scales_vector.end(), data);

      auto entry = std::make_pair(is_unsigned, tensor);
      info_map->insert(std::make_pair(var_name, entry));
      op_desc->RemoveAttr(unsigned_var_name);
      op_desc->RemoveAttr(vector_name);
    }
    // Only the first flagged tmp op is consumed.
    graph->RemoveNode(op_node);
    break;
  }
}

// Renames `op` in place to its "fused_*" counterpart when the mapping below
// has one (e.g. conv2d / depthwise_conv2d -> fused_conv2d, matmul /
// matmul_v2 -> fused_matmul). Ops without an entry keep their type; only a
// VLOG(3) message is emitted for them.
//
// For the legacy "matmul" op, its transpose_X / transpose_Y / alpha
// attributes are first copied to trans_x / trans_y / matmul_alpha
// (presumably the names fused_matmul reads — confirm against its op
// definition). The original attributes are left in place.
inline void ConvertToFusedOp(OpDesc* op) {
  // Built once: a non-static local map would allocate and populate all of
  // its nodes on every call. Static locals of an inline function are shared
  // across translation units, and their initialization is thread-safe.
  static const std::map<std::string, std::string> fused_ops = {
      {"conv2d", "fused_conv2d"},
      {"depthwise_conv2d", "fused_conv2d"},
      {"matmul", "fused_matmul"},
      {"matmul_v2", "fused_matmul"},
      {"softplus", "fused_softplus"},
      {"transpose2", "fused_transpose"}};

  if (op->Type() == "matmul") {
    op->SetAttr("trans_x", op->GetAttr("transpose_X"));
    op->SetAttr("trans_y", op->GetAttr("transpose_Y"));
    op->SetAttr("matmul_alpha", op->GetAttr("alpha"));
  }

  const auto it = fused_ops.find(op->Type());
  if (it != fused_ops.end()) {
    op->SetType(it->second);
    VLOG(3) << "Converted " << it->first << " to " << it->second;
  } else {
    VLOG(3) << "Fused op for " << op->Type() << " is not implemented yet.";
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle