// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <unordered_set>
#include <vector>
#ifdef CINN_WITH_CUDA
#include <cuda_runtime.h>
#endif

#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/pass/use_program_pass.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h"

DECLARE_bool(cinn_use_op_fusion);

namespace cinn {
namespace frontend {

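// Pretty-prints `bs` batches of an `m` x `n` matrix stored row-major in
// `mat`, padding each element to the width of the widest value. Only active
// when VLOG level >= 5.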
inline void PrintMatrix(const std::vector<float>& mat, int bs, int m, int n) {
  if (!VLOG_IS_ON(5)) {
    return;
  }
  if (mat.empty()) {
    return;  // minmax_element on an empty range must not be dereferenced
  }
  const auto min_max = std::minmax_element(mat.begin(), mat.end());
  int min = static_cast<int>(*min_max.first);
  int max = static_cast<int>(*min_max.second);
  auto ele_width =
      std::max(std::to_string(min).length(), std::to_string(max).length());
  std::cout << "\n" << std::string((ele_width + 2) * n - 1, '-') << "\n";
  for (int b = 0; b < bs; b++) {
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        std::cout << std::setw(ele_width) << mat[b * m * n + i * n + j] << ", ";
      }
      std::cout << "\n";
    }
    if (b != bs - 1) {
      std::cout << std::string((ele_width + 2) * n - 1, '*') << "\n";
    }
  }
  std::cout << std::string((ele_width + 2) * n - 1, '-') << "\n\n";
}

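// Applies `graph_passes` to `graph`, builds its scope, compiles it for
// `target` with instantiated variables, and executes the resulting runtime
// program; `output_ids` are handed to the compiler as fetch targets.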
inline void RunGraph(std::shared_ptr<hlir::framework::Graph> graph,
                     const common::Target& target,
                     const std::shared_ptr<hlir::framework::Scope>& scope,
                     const std::vector<std::string>& output_ids,
                     const std::vector<std::string>& graph_passes) {
  hlir::framework::ApplyPasses(graph.get(), graph_passes);
  VLOG(3) << "Graph Viz:\n" << graph->Visualize();
  BuildScope(target, graph, scope);
  hlir::framework::GraphCompiler::CompileOptions options;
  options.attached_code = "";
  options.with_instantiate_variables = true;
  hlir::framework::GraphCompiler gc(target, scope, graph);
  auto runtime_program = gc.Build(options,
                                  std::unordered_set<std::string>(
                                      output_ids.begin(), output_ids.end()))
                             .runtime_program;
  runtime_program->Execute();
}

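// Builds a graph from `program`, fills every input tensor with random
// integer-valued data (reproducible for a fixed `seed`; -1 draws a random
// seed), runs the graph with `graph_passes`, and returns the data of the
// first output tensor.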
inline std::vector<float> RunProgram(
    const Program& program,
    const common::Target& target,
    const std::vector<std::string>& input_ids,
    const std::vector<std::string>& output_ids,
    const std::vector<std::string>& graph_passes,
    int seed = -1,
    bool print_tensor = false) {
  std::unordered_set<std::string> outputs_set{output_ids.begin(),
                                              output_ids.end()};
  auto graph =
      std::make_shared<hlir::framework::Graph>(program, outputs_set, target);
  auto scope = hlir::framework::BuildScope(target, graph);
  for (auto& input_id : input_ids) {
    scope->Var<hlir::framework::Tensor>(input_id);
    auto input_tensor = scope->GetTensor(input_id);
    // Integer-valued random inputs keep floating-point results exact, so the
    // control and experimental runs can be compared with ASSERT_FLOAT_EQ.
    SetRandData<int>(input_tensor, target, seed);
    if (print_tensor) {
      auto tensor_data = GetTensorData<float>(input_tensor, target);
      if (input_tensor->shape().data().size() == 2) {
        PrintMatrix(tensor_data,
                    1,
                    input_tensor->shape().data()[0],
                    input_tensor->shape().data()[1]);
      } else if (input_tensor->shape().data().size() == 3) {
        PrintMatrix(tensor_data,
                    input_tensor->shape().data()[0],
                    input_tensor->shape().data()[1],
                    input_tensor->shape().data()[2]);
      }
    }
  }

  RunGraph(graph, target, scope, output_ids, graph_passes);

  auto output_tensor = scope->GetTensor(output_ids.front());
  auto output_data = GetTensorData<float>(output_tensor, target);
  if (print_tensor) {
    if (output_tensor->shape().data().size() == 2) {
      PrintMatrix(output_data,
                  1,
                  output_tensor->shape().data()[0],
                  output_tensor->shape().data()[1]);
    } else if (output_tensor->shape().data().size() == 3) {
      PrintMatrix(output_data,
                  output_tensor->shape().data()[0],
                  output_tensor->shape().data()[1],
                  output_tensor->shape().data()[2]);
    }
  }
  return output_data;
}

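// Pass configuration for an A/B pass test: `ctrl` holds the control
// (baseline) passes, `exp` the experimental ones that additionally contain
// the pass under test. When --cinn_use_op_fusion is set, both sides run the
// same fusion graph passes so the comparison stays fair.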
struct OptimizeConfig {
  struct PassGroup;
  explicit OptimizeConfig(const PassGroup& program_passes)
      : program_passes{program_passes} {
    if (FLAGS_cinn_use_op_fusion) {
      graph_passes = {{"OpFusionPass", "FusionMergePass"},
                      {"OpFusionPass", "FusionMergePass"}};
    }
  }
  OptimizeConfig(const PassGroup& program_passes, const PassGroup& graph_passes)
      : program_passes{program_passes}, graph_passes{graph_passes} {}

  OptimizeConfig(const std::pair<std::vector<std::string>,
                                 std::vector<std::string>>& program_passes) {
    this->program_passes.ctrl = program_passes.first;
    this->program_passes.exp = program_passes.second;

    if (FLAGS_cinn_use_op_fusion) {
      graph_passes = {
          {"TransToCustomCallPass", "OpFusionPass", "FusionMergePass"},
          {"TransToCustomCallPass", "OpFusionPass", "FusionMergePass"}};
    }
  }

  struct PassGroup {
    // control group
    std::vector<std::string> ctrl;
    // experimental group
    std::vector<std::string> exp;
  };
  PassGroup program_passes;
  PassGroup graph_passes;
};

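// Applies the control program passes and runs `program`, then applies the
// experimental passes on top and runs it again; asserts that the program
// shrank by exactly `size_diff` instructions and that both runs produce
// element-wise identical outputs.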
inline void CompareResult(Program* program,
                          const common::Target& target,
                          const std::vector<std::string>& input_ids,
                          const std::vector<std::string>& output_ids,
                          size_t size_diff,
                          const OptimizeConfig& passes,
                          int seed = -1,
                          bool print_tensor = false) {
  std::unordered_set<std::string> fetch_ids(output_ids.begin(),
                                            output_ids.end());
  // apply common passes
  ProgramPass::Apply(program, fetch_ids, target, passes.program_passes.ctrl);
  // get original program size
  auto origin_size = program->size();
  // get original output
  auto origin_out = RunProgram(*program,
                               target,
                               input_ids,
                               output_ids,
                               passes.graph_passes.ctrl,
                               seed,
                               print_tensor);

  // apply fused passes
  ProgramPass::Apply(program, fetch_ids, target, passes.program_passes.exp);

  // get fused program size
  auto fused_size = program->size();
  ASSERT_EQ(size_diff, origin_size - fused_size);
  // get fused output
  auto fused_out = RunProgram(*program,
                              target,
                              input_ids,
                              output_ids,
                              passes.graph_passes.exp,
                              seed,
                              print_tensor);

  ASSERT_EQ(origin_out.size(), fused_out.size());
  for (size_t i = 0; i < origin_out.size(); ++i) {
    ASSERT_FLOAT_EQ(origin_out[i], fused_out[i]) << " i is " << i;
  }
}
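
// Example (a minimal sketch, not part of this helper): a typical pass test
// builds a Program with NetBuilder and hands it to CompareResult. The op,
// shapes, pass names, and the expected size reduction below are illustrative
// assumptions; substitute the pass under test.
//
//   NetBuilder builder("net_builder");
//   auto a = builder.CreateInput(Float(32), {32, 16}, "A");
//   auto b = builder.CreateInput(Float(32), {16, 32}, "B");
//   auto out = builder.Matmul(a, b);
//   auto program = builder.Build();
//   auto target = common::DefaultTarget();
//   std::vector<std::string> input_ids{std::string(a.id()),
//                                      std::string(b.id())};
//   OptimizeConfig passes({{"Decomposer"},
//                          {"Decomposer", "TransposeFoldingInput"}});
//   CompareResult(&program, target, input_ids, {out->id},
//                 /*size_diff=*/0, passes, /*seed=*/123);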

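// Size-only variant of CompareResult: applies the control and then the
// experimental program passes without executing anything, and returns
// whether the program shrank by exactly `size_diff` instructions.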
inline bool CompareProgramPassResult(
    Program* program,
    const common::Target& target,
    const std::unordered_set<std::string>& fetch_ids,
    const size_t size_diff,
    const OptimizeConfig& passes) {
  // apply common passes
  ProgramPass::Apply(program, fetch_ids, target, passes.program_passes.ctrl);
  // get original program size
  auto origin_size = program->size();

  // apply fused passes
  ProgramPass::Apply(program, fetch_ids, target, passes.program_passes.exp);

  // get fused program size
  auto fused_size = program->size();
  return size_diff == (origin_size - fused_size);
}

}  // namespace frontend
}  // namespace cinn