pass_manager_test.cc 10.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
16
#include "glog/logging.h"
17

18 19 20
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h"
21
#include "paddle/fluid/ir/interface/op_yaml_info.h"
22 23
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h"
24 25 26 27 28
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/operation.h"
29 30
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
31
#include "paddle/phi/kernels/elementwise_add_kernel.h"
32

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
#ifndef _WIN32
class TestAnalysis1 {};
class TestAnalysis2 {};

IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis1)
IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis1)
IR_DECLARE_EXPLICIT_TYPE_ID(TestAnalysis2)
IR_DEFINE_EXPLICIT_TYPE_ID(TestAnalysis2)

TEST(pass_manager, PreservedAnalyses) {
  ir::detail::PreservedAnalyses pa;
  CHECK_EQ(pa.IsNone(), true);

  CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), false);
  pa.Preserve<TestAnalysis1>();
  CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), true);
  pa.Unpreserve<TestAnalysis1>();
  CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), false);
  CHECK_EQ(pa.IsPreserved<TestAnalysis2>(), false);
  pa.Preserve<TestAnalysis1, TestAnalysis2>();
  CHECK_EQ(pa.IsPreserved<TestAnalysis1>(), true);
  CHECK_EQ(pa.IsPreserved<TestAnalysis2>(), true);
  CHECK_EQ(pa.IsAll(), false);
  pa.PreserveAll();
  CHECK_EQ(pa.IsAll(), true);
  CHECK_EQ(pa.IsNone(), false);
}
#endif

62
class AddOp : public ir::Op<AddOp> {
63 64
 public:
  using Op::Op;
65 66 67
  static const char *name() { return "test.add"; }
  static constexpr const char **attributes_name = nullptr;
  static constexpr uint32_t attributes_num = 0;
68
  void Verify();
69 70 71 72 73
  static void Build(ir::Builder &builder,             // NOLINT
                    ir::OperationArgument &argument,  // NOLINT
                    ir::OpResult l_operand,
                    ir::OpResult r_operand,
                    ir::Type sum_type);
74
};
75 76 77 78 79 80 81 82
void AddOp::Verify() {
  if (num_operands() != 2) {
    throw("The size of inputs must be equal to 2.");
  }
  if (num_results() != 1) {
    throw("The size of outputs must be equal to 1.");
  }
}
83 84 85 86 87 88 89 90 91
void AddOp::Build(ir::Builder &,
                  ir::OperationArgument &argument,
                  ir::OpResult l_operand,
                  ir::OpResult r_operand,
                  ir::Type sum_type) {
  argument.AddOperand(l_operand);
  argument.AddOperand(r_operand);
  argument.AddOutput(sum_type);
}
92 93
IR_DECLARE_EXPLICIT_TYPE_ID(AddOp)
IR_DEFINE_EXPLICIT_TYPE_ID(AddOp)
94

95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
struct CountOpAnalysis {
  explicit CountOpAnalysis(ir::Operation *container_op) {
    IR_ENFORCE(container_op->num_regions() > 0, true);

    LOG(INFO) << "In CountOpAnalysis, op is " << container_op->name() << "\n";
    for (size_t i = 0; i < container_op->num_regions(); ++i) {
      auto &region = container_op->region(i);
      for (auto it = region.begin(); it != region.end(); ++it) {
        auto *block = *it;
        for (auto it = block->begin(); it != block->end(); ++it) {
          ++count;
        }
      }
    }

    LOG(INFO) << "-- count is " << count << "\n";
  }

  int count = 0;
};

IR_DECLARE_EXPLICIT_TYPE_ID(CountOpAnalysis)
IR_DEFINE_EXPLICIT_TYPE_ID(CountOpAnalysis)

119 120 121 122
class TestPass : public ir::Pass {
 public:
  TestPass() : ir::Pass("TestPass", 1) {}
  void Run(ir::Operation *op) override {
123 124 125 126 127 128
    auto count_op_analysis = analysis_manager().GetAnalysis<CountOpAnalysis>();
    pass_state().preserved_analyses.Preserve<CountOpAnalysis>();
    CHECK_EQ(pass_state().preserved_analyses.IsPreserved<CountOpAnalysis>(),
             true);
    CHECK_EQ(count_op_analysis.count, 4);

129 130 131 132 133
    auto module_op = op->dyn_cast<ir::ModuleOp>();
    CHECK_EQ(module_op.operation(), op);
    CHECK_EQ(module_op.name(), module_op->name());
    LOG(INFO) << "In " << pass_info().name << ": " << module_op->name()
              << std::endl;
134 135 136 137

    pass_state().preserved_analyses.Unpreserve<CountOpAnalysis>();
    CHECK_EQ(pass_state().preserved_analyses.IsPreserved<CountOpAnalysis>(),
             false);
138 139
  }

140 141
  bool CanApplyOn(ir::Operation *op) const override {
    return op->name() == "builtin.module" && op->num_regions() > 0;
142 143 144
  }
};

145 146 147 148 149
TEST(pass_manager, PassManager) {
  //
  // TODO(liuyuanle): remove test code other than pass manager
  //

150
  // (1) Init environment.
151
  ir::IrContext *ctx = ir::IrContext::Instance();
152 153 154 155 156 157 158 159 160 161 162
  ir::Dialect *builtin_dialect =
      ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
  builtin_dialect->RegisterOp<AddOp>();
  ir::Dialect *paddle_dialect =
      ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();

  // (2) Create an empty program object
  ir::Program program(ctx);

  // (3) Create a float32 DenseTensor Parameter and save into Program
  ir::Type fp32_dtype = ir::Float32Type::get(ctx);
163 164 165
  phi::DDim dims = {2, 2};
  phi::DataLayout data_layout = phi::DataLayout::NCHW;
  phi::LoD lod = {{0, 1, 2}};
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
  size_t offset = 0;
  ir::Type dense_tensor_dtype = paddle::dialect::DenseTensorType::get(
      ctx, fp32_dtype, dims, data_layout, lod, offset);

  std::vector<float> data_a = {1, 2, 3, 4};
  std::unique_ptr<ir::Parameter> parameter_a =
      std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_a.data()),
                                      4 * sizeof(float),
                                      dense_tensor_dtype);
  program.SetParameter("a", std::move(parameter_a));
  EXPECT_EQ(program.parameters_num() == 1, true);

  std::vector<float> data_b = {5, 6, 7, 8};
  std::unique_ptr<ir::Parameter> parameter_b =
      std::make_unique<ir::Parameter>(reinterpret_cast<void *>(data_b.data()),
                                      4 * sizeof(float),
                                      dense_tensor_dtype);
  program.SetParameter("b", std::move(parameter_b));
  EXPECT_EQ(program.parameters_num() == 2, true);

  // (4) Def a = GetParameterOp("a"), and create DenseTensor for a.
187 188
  ir::Builder builder(ctx, program.block());
  auto op1 = builder.Build<ir::GetParameterOp>("a", dense_tensor_dtype);
189 190

  EXPECT_EQ(&program, op1->GetParentProgram());
191
  EXPECT_EQ(op1->result(0).type().dialect().id(), paddle_dialect->id());
192
  using Interface = paddle::dialect::ParameterConvertInterface;
193 194
  Interface *a_interface =
      op1->result(0).type().dialect().GetRegisteredInterface<Interface>();
195 196 197 198 199 200 201 202 203 204 205 206 207 208
  std::shared_ptr<paddle::framework::Variable> a_var =
      a_interface->ParameterToVariable(program.GetParameter("a"));
  const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
  EXPECT_EQ(a_tensor.numel(), 4);
  EXPECT_EQ(a_tensor.dims(), dims);
  EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
  EXPECT_EQ(a_tensor.layout(), data_layout);
  EXPECT_EQ(a_tensor.lod(), lod);
  EXPECT_EQ(a_tensor.offset(), offset);
  for (int64_t i = 0; i < a_tensor.numel(); i++) {
    EXPECT_EQ(*(a_tensor.data<float>() + i), data_a[i]);
  }

  // (5) Def b = GetParameterOp("b"), and create DenseTensor for b.
209
  auto op2 = builder.Build<ir::GetParameterOp>("b", dense_tensor_dtype);
210 211 212
  EXPECT_EQ(op2->result(0).type().dialect().id(), paddle_dialect->id());
  Interface *b_interface =
      op2->result(0).type().dialect().GetRegisteredInterface<Interface>();
213 214 215 216 217 218 219 220 221 222 223 224 225 226
  std::shared_ptr<paddle::framework::Variable> b_var =
      b_interface->ParameterToVariable(program.GetParameter("b"));
  const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
  EXPECT_EQ(b_tensor.numel(), 4);
  EXPECT_EQ(b_tensor.dims(), dims);
  EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
  EXPECT_EQ(b_tensor.layout(), data_layout);
  EXPECT_EQ(b_tensor.lod(), lod);
  EXPECT_EQ(b_tensor.offset(), offset);
  for (int64_t i = 0; i < b_tensor.numel(); i++) {
    EXPECT_EQ(*(b_tensor.data<float>() + i), data_b[i]);
  }

  // (6) Def c = AddOp(a, b), execute this op.
227 228
  auto op3 =
      builder.Build<AddOp>(op1->result(0), op2->result(0), dense_tensor_dtype);
229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
  phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
      paddle::platform::DeviceContextPool::Instance().Get(
          paddle::platform::CPUPlace()));
  phi::DenseTensor c_tensor =
      phi::Add<float, phi::CPUContext>(*dev_ctx, a_tensor, b_tensor);
  std::shared_ptr<paddle::framework::Variable> variable_c =
      std::make_shared<paddle::framework::Variable>();
  auto *dst_tensor = variable_c->GetMutable<phi::DenseTensor>();
  *dst_tensor = c_tensor;
  EXPECT_EQ(dst_tensor->numel(), b_tensor.numel());
  EXPECT_EQ(dst_tensor->dims(), b_tensor.dims());
  EXPECT_EQ(dst_tensor->dtype(), b_tensor.dtype());
  EXPECT_EQ(dst_tensor->layout(), b_tensor.layout());
  EXPECT_EQ(dst_tensor->lod(), b_tensor.lod());
  EXPECT_EQ(dst_tensor->offset(), b_tensor.offset());
  for (int64_t i = 0; i < dst_tensor->numel(); i++) {
    EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]);
  }

248 249
  // (7) Def SetParameterOp(c, "c")
  auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
250
  EXPECT_EQ(op4->operand(0).source().type().dialect().id(),
251
            paddle_dialect->id());
252 253
  Interface *c_interface =
      op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
254 255
  //   ir::Parameter *parameter_c =
  //       c_interface->VariableToParameter(variable_c.get());
256

257 258 259 260 261 262 263 264 265 266 267 268 269
  std::unique_ptr<ir::Parameter> parameter_c =
      c_interface->VariableToParameter(variable_c.get());
  EXPECT_EQ(parameter_c->type(), dense_tensor_dtype);
  for (int64_t i = 0; i < dst_tensor->numel(); i++) {
    EXPECT_EQ(*(dst_tensor->data<float>() + i),
              *(static_cast<float *>(parameter_c->data()) + i));
  }
  program.SetParameter("c", std::move(parameter_c));

  // (8) Traverse Program
  EXPECT_EQ(program.block()->size() == 4, true);
  EXPECT_EQ(program.parameters_num() == 3, true);

270 271 272 273
  //
  // TODO(liuyuanle): remove the code above.
  //

274
  // (9) Test pass manager for program.
275
  ir::PassManager pm(ctx);
276

277
  pm.AddPass(std::make_unique<TestPass>());
278

279 280
  pm.EnableIRPrinting(std::make_unique<ir::PassManager::IRPrinterOption>(
      [](ir::Pass *pass, ir::Operation *op) {
281
        return pass->name() == "TestPass";
282 283
      },
      [](ir::Pass *pass, ir::Operation *op) {
284
        return pass->name() == "TestPass";
285 286
      },
      true,
287 288 289
      true));

  pm.EnablePassTiming(true);
290

291
  CHECK_EQ(pm.Run(&program), true);
292
}