// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "paddle/cinn/frontend/decomposer/test_helper.h"

namespace cinn {
namespace frontend {

22 23 24
int GetSize(std::vector<int>& shape) {
  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}
25 26 27 28 29 30 31 32

void RunModelTest(Program& program,
                  const std::vector<Variable>&& inputs,
                  const std::unordered_set<std::string>& fetch_ids) {
  // init input data.
  std::vector<std::vector<float>> inputs_data;
  for (auto input : inputs) {
    inputs_data.emplace_back(GetSize(input->shape));
33 34
    InitRandomVector<float>(
        &inputs_data.back(), inputs_data.back().size(), 0.0f, 1.0f, 1e-3);
35 36 37
  }

  auto target = common::DefaultTarget();
38 39 40
  std::unordered_map<std::string,
                     std::pair<std::vector<float>, std::vector<float>>>
      outputs;
41
  {
42 43
    auto graph =
        std::make_shared<hlir::framework::Graph>(program, fetch_ids, target);
44 45 46 47 48 49 50 51 52 53
    hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
    hlir::framework::ApplyPass(graph.get(), "FusionMergePass");

    auto scope = BuildScope(target, graph);
    hlir::framework::GraphCompiler gc(target, scope, graph);
    auto run_program = gc.Build();

    for (int idx = 0; idx < inputs.size(); ++idx) {
      scope->Var<hlir::framework::Tensor>(inputs[idx]->id);
      auto tensor = scope->GetTensor(inputs[idx]->id);
54
      auto* data = tensor->mutable_data<float>(target);
55 56 57 58 59 60 61
      CopyFromVector(inputs_data[idx], tensor, target);
    }
    run_program->Execute();
    for (auto id : fetch_ids) {
      auto tensor = scope->GetTensor(id);
      std::vector<float> data(tensor->shape().numel());
      CopyToVector(tensor, &data);
62 63
      outputs[id] = std::pair<std::vector<float>, std::vector<float>>(
          data, std::vector<float>());
64 65 66
    }
  }
  {
67 68
    auto graph =
        std::make_shared<hlir::framework::Graph>(program, fetch_ids, target);
69 70 71 72 73 74 75 76 77 78 79
    hlir::framework::ApplyPass(graph.get(), "DotMerger");
    hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
    hlir::framework::ApplyPass(graph.get(), "FusionMergePass");

    auto scope = BuildScope(target, graph);
    hlir::framework::GraphCompiler gc(target, scope, graph);
    auto run_program = gc.Build();

    for (int idx = 0; idx < inputs.size(); ++idx) {
      scope->Var<hlir::framework::Tensor>(inputs[idx]->id);
      auto tensor = scope->GetTensor(inputs[idx]->id);
80
      auto* data = tensor->mutable_data<float>(target);
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
      CopyFromVector(inputs_data[idx], tensor, target);
    }
    run_program->Execute();
    for (auto id : fetch_ids) {
      auto tensor = scope->GetTensor(id);
      std::vector<float> data(tensor->shape().numel());
      CopyToVector(tensor, &data);
      outputs[id].second = data;
    }
  }

  for (auto& output : outputs) {
    CheckOutput<float>(output.second.first, output.second.second, 1e-8, 1e-4);
  }
}

TEST(DotMerger, Test_dot_merger0) {
  int m = 2, k = 1024, n = 100, n1 = 100, n2 = 100, axis = 1;
  NetBuilder net_builder("Test_dot_merger0");
100 101 102 103 104 105 106 107 108 109 110 111 112
  auto A = net_builder.CreateInput(Float(32), {m, k}, "A");
  auto B = net_builder.CreateInput(Float(32), {k, n1}, "B");
  auto C = net_builder.CreateInput(Float(32), {k, n2}, "C");
  auto D = net_builder.CreateInput(Float(32), {n1, k}, "D");
  auto E = net_builder.CreateInput(Float(32), {n2, k}, "E");
  auto F = net_builder.CreateInput(Float(32), {k, n}, "F");
  auto G = net_builder.Matmul(A, B);
  auto H = net_builder.Matmul(A, C);
  auto G1 = net_builder.Matmul(D, F);
  auto H1 = net_builder.Matmul(E, F);
  auto G2 = net_builder.Concat({G, H}, axis);
  auto H2 = net_builder.Concat({G1, H1}, (1 - axis));
  auto F1 = net_builder.Matmul(G2, H2);
113
  auto fetch_ids = {F1->id};
114
  auto program = net_builder.Build();
115 116 117 118 119
  std::cout << "RunModelTest" << std::endl;
  RunModelTest(program, {A, B, C, D, E, F}, fetch_ids);
}

}  // namespace frontend
120
}  // namespace cinn