// test_helper.h
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <gtest/gtest.h>

#include <random>

#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/pass/use_program_pass.h"
#include "paddle/cinn/frontend/program_pass.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"

namespace cinn::frontend {

// Returns `numel` pseudo-random values. Each value is drawn uniformly from
// [0, 10) as a float and then converted to T with static_cast.
// A new engine is seeded from std::random_device on every call, so repeated
// calls produce different data.
template <typename T>
std::vector<T> GeneratedRandomVector(size_t numel) {
  std::random_device entropy;
  std::default_random_engine gen(entropy());
  std::uniform_real_distribution<float> uniform(0.f, 10.f);

  std::vector<T> values;
  values.reserve(numel);
  while (values.size() < numel) {
    values.push_back(static_cast<T>(uniform(gen)));
  }
  return values;
}

template <typename T>
45 46 47
void CopyFromVector(const std::vector<T>& src,
                    hlir::framework::Tensor tensor,
                    Target target) {
48
  size_t numel = tensor->shape().numel();
49
  auto* dst = tensor->mutable_data<T>(target);
50 51 52 53 54 55 56 57 58 59 60

#ifdef CINN_WITH_CUDA
  cudaMemcpy(dst, src.data(), numel * sizeof(T), cudaMemcpyHostToDevice);
#else
  std::copy(src.begin(), src.end(), dst);
#endif
}

// Copies the contents of `tensor` into a newly allocated host vector of
// `tensor->shape().numel()` elements, reading the data as T.
//
// With CINN_WITH_CUDA the tensor data is copied with a device-to-host
// cudaMemcpy (a failed copy aborts via CHECK); otherwise it is read directly.
template <typename T>
std::vector<T> CopyToVector(const hlir::framework::Tensor tensor) {
  const size_t numel = tensor->shape().numel();
  const auto* src = tensor->data<T>();

  std::vector<T> dst(numel);
#ifdef CINN_WITH_CUDA
  const cudaError_t status =
      cudaMemcpy(dst.data(), src, numel * sizeof(T), cudaMemcpyDeviceToHost);
  CHECK_EQ(status, cudaSuccess)
      << "cudaMemcpy (device to host) failed: " << cudaGetErrorString(status);
#else
  std::copy(src, src + numel, dst.begin());
#endif
  return dst;
}

class PassTest {
 public:
  PassTest() { target_ = common::DefaultTarget(); }

  int RunAndCheck(NetBuilder& builder,
                  const std::vector<std::string>& program_passes,
                  const std::vector<std::string>& input_names,
                  const std::vector<std::string>& output_names) {
    auto program = builder.Build();
    CHECK(IsValid(program)) << "The origin program is not valid.";
    int origin_program_size = program.size();
    LOG(INFO) << "Run origin program";
86 87
    std::unordered_map<std::string, std::vector<float>> origin_outputs =
        Execute(program, input_names, output_names);
88

89 90
    std::unordered_set<std::string> fetch_var_ids(output_names.begin(),
                                                  output_names.end());
91 92 93 94
    ProgramPass::Apply(&program, fetch_var_ids, target_, program_passes);
    int optimized_program_size = program.size();
    CHECK(IsValid(program)) << "The optimized program is not valid.";
    LOG(INFO) << "Run optimized program";
95 96
    std::unordered_map<std::string, std::vector<float>> optimized_outputs =
        Execute(program, input_names, output_names);
97 98 99 100 101 102 103 104 105 106 107

    for (auto name : output_names) {
      LOG(INFO) << "Check output name=" << name;
      CHECK(origin_outputs.count(name));
      CHECK(optimized_outputs.count(name));
      CheckOutput(optimized_outputs[name], origin_outputs[name]);
    }
    return origin_program_size - optimized_program_size;
  }

 protected:
108 109 110 111
  std::unordered_map<std::string, std::vector<float>> Execute(
      const Program& program,
      const std::vector<std::string>& input_names,
      const std::vector<std::string>& output_names) {
112
    LOG(INFO) << program;
113 114 115 116
    std::unordered_set<std::string> fetch_var_ids(output_names.begin(),
                                                  output_names.end());
    auto graph = std::make_shared<hlir::framework::Graph>(
        program, fetch_var_ids, target_);
117 118 119 120 121 122
    hlir::framework::ApplyPasses(graph.get(), DefaultOpFusionPasses());

    auto scope = hlir::framework::BuildScope(target_, graph);
    hlir::framework::GraphCompiler gc(target_, scope, graph);
    hlir::framework::GraphCompiler::CompileOptions options;
    options.with_instantiate_variables = true;
123 124
    auto result = gc.Build(options, std::move(fetch_var_ids));
    auto runtime_program = std::move(result.runtime_program);
125 126 127 128 129 130 131 132

    for (auto& name : input_names) {
      SetInputTensor(name, scope);
    }
    runtime_program->Execute();

    std::unordered_map<std::string, std::vector<float>> outputs;
    for (auto& name : output_names) {
133
      auto tensor = scope->GetTensor(name);
134 135 136 137 138 139
      std::vector<float> vec = CopyToVector<float>(tensor);
      outputs.emplace(name, vec);
    }
    return outputs;
  }

140 141
  void SetInputTensor(const std::string& name,
                      std::shared_ptr<hlir::framework::Scope> scope) {
142 143 144 145
    scope->Var<hlir::framework::Tensor>(name);
    auto tensor = scope->GetTensor(name);

    if (!inputs_.count(name)) {
146 147
      std::vector<float> vec =
          GeneratedRandomVector<float>(tensor->shape().numel());
148 149 150 151 152 153
      inputs_.emplace(name, vec);
    }
    auto iter = inputs_.find(name);
    CopyFromVector<float>(iter->second, tensor, target_);
  }

154 155
  void CheckOutput(const std::vector<float>& actual,
                   const std::vector<float>& expect) {
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
    CHECK_EQ(actual.size(), expect.size());
    for (size_t i = 0; i < expect.size(); ++i) {
      ASSERT_FLOAT_EQ(actual[i], expect[i]);
    }
  }

  bool IsValid(const Program& program) {
    std::unordered_set<std::string> inputs;
    for (auto& var : program.GetInputs()) {
      inputs.insert(var->id);
    }

    std::unordered_set<std::string> outputs;
    for (int i = 0; i < program.size(); ++i) {
      const auto& instr = program[i];
      for (auto& var : instr->outputs) {
        outputs.insert(var->id);
      }
    }

    bool valid = true;
    for (int i = 0; i < program.size(); ++i) {
      const auto& instr = program[i];
      // The inputs should be feeded, or other instructions' output.
      for (auto& var : instr->inputs) {
        if (!inputs.count(var->id) && !outputs.count(var->id)) {
182 183
          LOG(INFO) << "The input " << var->id << " of " << i
                    << "-th instrution (" << instr
184 185 186 187 188 189 190 191 192 193 194 195 196 197
                    << ") is not the output of any other instructions.";
          valid = false;
        }
      }
    }

    return valid;
  }

  Target target_;
  std::unordered_map<std::string, std::vector<float>> inputs_;
};

}  // namespace cinn::frontend