From a9b1e887f7a82cdb4ec19688a38d3497c212cc7b Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 22 May 2023 15:03:44 +0800 Subject: [PATCH] [NewIR] Printer for Program/Operation/Type and Register ops for ResNet50 (#53988) * add conv2d * printer draft * fix bug * printer draft finish * fix windows CI * commit printer and resnet50 related ops * fix * fix * fix op definition --------- Co-authored-by: umiswing Co-authored-by: zhangbo9674 --- paddle/fluid/dialect/pd_dialect.cc | 43 ++++++ paddle/fluid/dialect/pd_dialect.h | 2 + paddle/fluid/dialect/pd_op.h | 64 +++++++++ paddle/ir/dialect.h | 6 + paddle/ir/printer.cc | 203 +++++++++++++++++++++++++++++ paddle/ir/program.h | 2 + paddle/ir/type.h | 4 + paddle/ir/value.h | 4 + 8 files changed, 328 insertions(+) create mode 100644 paddle/fluid/dialect/pd_op.h create mode 100644 paddle/ir/printer.cc diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index 4439110adcb..c7235060006 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ b/paddle/fluid/dialect/pd_dialect.cc @@ -14,7 +14,9 @@ #include "paddle/fluid/dialect/pd_dialect.h" #include "paddle/fluid/dialect/pd_attribute.h" +#include "paddle/fluid/dialect/pd_op.h" #include "paddle/fluid/dialect/pd_type.h" +#include "paddle/fluid/dialect/pd_type_storage.h" #include "paddle/fluid/dialect/utils.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/data_type.h" @@ -93,6 +95,47 @@ void PaddleDialect::initialize() { RegisterTypes(); RegisterAttributes(); RegisterInterfaces(); + RegisterOps(); +} + +void PaddleDialect::PrintType(ir::Type type, std::ostream& os) { + DenseTensorType tensor_type = type.dyn_cast(); + + os << "tensor<"; + auto& dims = tensor_type.dim(); + for (auto d : dims) { + os << d; + os << "x"; + } + tensor_type.dtype().print(os); + os << ">"; } } // namespace dialect diff --git a/paddle/fluid/dialect/pd_dialect.h b/paddle/fluid/dialect/pd_dialect.h index a81ff7cd48c..6b1312e6913 100644 --- a/paddle/fluid/dialect/pd_dialect.h +++ b/paddle/fluid/dialect/pd_dialect.h @@ -39,6 +39,8 @@ class PaddleDialect : public ir::Dialect { static const char* name() { return "pd"; } + void PrintType(ir::Type type, std::ostream& os); + private: void initialize(); }; diff --git a/paddle/fluid/dialect/pd_op.h b/paddle/fluid/dialect/pd_op.h new file mode 100644 index 00000000000..344efcc9e95 --- /dev/null +++ b/paddle/fluid/dialect/pd_op.h @@ -0,0 +1,64 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/op_base.h" + +namespace paddle { +namespace dialect { + +#define OPNAME(op_name) "pd." #op_name + +#define REIGSTER_EMPTY_OP(op_name, className) \ + class className : public ir::Op { \ + public: \ + static const char *name() { return OPNAME(op_name); } \ + static const char **attributes_name; \ + static constexpr uint32_t attributes_num = 0; \ + }; \ + const char **className::attributes_name = nullptr; + +REIGSTER_EMPTY_OP(conv2d, Conv2DOp); +REIGSTER_EMPTY_OP(feed, FeedOp); +REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); +REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); +REIGSTER_EMPTY_OP(relu, ReluOp); +REIGSTER_EMPTY_OP(elementwise_add, ElementwiseAddOp); +REIGSTER_EMPTY_OP(pool2d, Pool2DOp); +REIGSTER_EMPTY_OP(flatten_contiguous_range, FlattenContiguousRangeOp); +REIGSTER_EMPTY_OP(matmul_v2, MatmulV2Op); +REIGSTER_EMPTY_OP(reshape2, Reshape2Op); +REIGSTER_EMPTY_OP(softmax_with_cross_entropy, SoftmaxWithCrossEntropyOp); +REIGSTER_EMPTY_OP(reduce_mean, ReduceMeanOp); +REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); +REIGSTER_EMPTY_OP(scale, ScaleOp); +REIGSTER_EMPTY_OP(accuracy, AccuracyOp); +REIGSTER_EMPTY_OP(fill_constant, FillConstantOp); +REIGSTER_EMPTY_OP(reduce_mean_grad, ReduceMeanGradOp); +REIGSTER_EMPTY_OP(softmax_with_cross_entropy_grad, + SoftmaxWithCrossEntropyGradOp); +REIGSTER_EMPTY_OP(elementwise_add_grad, ElementwiseAddGradOp); +REIGSTER_EMPTY_OP(matmul_v2_grad, MatmulV2GradOp); +REIGSTER_EMPTY_OP(flatten_contiguous_range_grad, FlattenContiguousRangeGradOp); +REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); +REIGSTER_EMPTY_OP(relu_grad, ReluGradOp); +REIGSTER_EMPTY_OP(batch_norm_grad, BatchNormGradOp); +REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); +REIGSTER_EMPTY_OP(sum, SumOp); +REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); +REIGSTER_EMPTY_OP(merged_momentum_, MergedMomentumOp_); + +} // namespace dialect +} // namespace paddle diff --git a/paddle/ir/dialect.h b/paddle/ir/dialect.h index 9fdb931f733..18870d0b049 100644 --- a/paddle/ir/dialect.h +++ b/paddle/ir/dialect.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "paddle/ir/attribute_base.h" #include "paddle/ir/dialect_interface.h" #include "paddle/ir/ir_context.h" @@ -127,6 +129,10 @@ class Dialect { return *interface; } + virtual void PrintType(ir::Type type, std::ostream &os) { + throw std::logic_error("dialect has no registered type printing hook"); + } + private: Dialect(const Dialect &) = delete; diff --git a/paddle/ir/printer.cc b/paddle/ir/printer.cc new file mode 100644 index 00000000000..b4d3acb930a --- /dev/null +++ b/paddle/ir/printer.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/ir/builtin_attribute.h" +#include "paddle/ir/builtin_type.h" +#include "paddle/ir/dialect.h" +#include "paddle/ir/operation.h" +#include "paddle/ir/program.h" +#include "paddle/ir/value.h" + +namespace ir { + +namespace { +constexpr char newline[] = "\n"; +} // namespace + +class Printer { + public: + explicit Printer(std::ostream& os) : os(os) {} + + void PrintType(ir::Type type) { + if (type.isa()) { + os << "f16"; + } else if (type.isa()) { + os << "f32"; + } else if (type.isa()) { + os << "f64"; + } else if (type.isa()) { + os << "i16"; + } else if (type.isa()) { + os << "i32"; + } else if (type.isa()) { + os << "i64"; + } else { + auto& dialect = type.dialect(); + dialect.PrintType(type, os); + } + } + + public: + std::ostream& os; +}; + +void Type::print(std::ostream& os) const { + if (!*this) { + os << ""; + return; + } + Printer p(os); + p.PrintType(*this); +} + +class ProgramPrinter : public Printer { + public: + explicit ProgramPrinter(std::ostream& os) : Printer(os), cur_var_number(0) {} + + void Print(ir::Program& program) { + for (auto* op : program.ops()) { + PrintOperation(op); + os << newline; + } + } + + template + void PrintInterleave(ForwardIterator begin, + ForwardIterator end, + UnaryFunctor print_func, + NullFunctor between_func) { + if (begin == end) return; + print_func(*begin); + begin++; + for (; begin != end; begin++) { + between_func(); + print_func(*begin); + } + } + + void PrintValue(ir::Value v) { + const void* key = static_cast(v.impl()); + auto ret = aliases.find(key); + if (ret != aliases.end()) { + os << ret->second; + return; + } + + std::string new_name = "%" + std::to_string(cur_var_number); + cur_var_number++; + aliases[key] = new_name; + os << new_name; + } + + /// @brief print operation + /// @param op + /// @example + void PrintOperation(ir::Operation* op) { + PrintOpResult(op); // TODO(lyk): add API to get opresults directly + os << " = "; + + os << "\"" << op->op_name() << "\""; + PrintOpOperands(op); // TODO(lyk): add API to get operands directly + + PrintAttribute(op); + os << " : "; + + // PrintOpSingature + PrintOperandsType(op); + os << " -> "; + PrintOpReturnType(op); // TODO(lyk): add API to get opresults directly + } + + void PrintOpResult(ir::Operation* op) { + os << " ("; + auto num_op_result = op->num_results(); + std::vector op_results; + op_results.reserve(num_op_result); + for (size_t idx = 0; idx < num_op_result; idx++) { + op_results.push_back(op->GetResultByIndex(idx)); + } + PrintInterleave( + op_results.begin(), + op_results.end(), + [this](ir::Value v) { this->PrintValue(v); }, + [this]() { this->os << ","; }); + os << ") "; + } + + void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE } "; } + + void PrintOpOperands(ir::Operation* op) { + os << " ("; + auto num_op_operands = op->num_operands(); + std::vector op_operands; + op_operands.reserve(num_op_operands); + for (size_t idx = 0; idx < num_op_operands; idx++) { + op_operands.push_back(op->GetOperandByIndex(idx).impl()->source()); + } + PrintInterleave( + op_operands.begin(), + op_operands.end(), + [this](ir::Value v) { this->PrintValue(v); }, + [this]() { this->os << ","; }); + os << ") "; + } + + void PrintOperandsType(ir::Operation* op) { + auto num_op_operands = op->num_operands(); + std::vector op_operand_types; + op_operand_types.reserve(num_op_operands); + for (size_t idx = 0; idx < num_op_operands; idx++) { + op_operand_types.push_back( + op->GetOperandByIndex(idx).impl()->source().type()); + } + PrintInterleave( + op_operand_types.begin(), + op_operand_types.end(), + [this](ir::Type t) { this->PrintType(t); }, + [this]() { this->os << ","; }); + } + + void PrintOpReturnType(ir::Operation* op) { + auto num_op_result = op->num_results(); + std::vector op_result_types; + op_result_types.reserve(num_op_result); + for (size_t idx = 0; idx < num_op_result; idx++) { + op_result_types.push_back(op->GetResultByIndex(idx).type()); + } + PrintInterleave( + op_result_types.begin(), + op_result_types.end(), + [this](ir::Type t) { this->PrintType(t); }, + [this]() { this->os << ","; }); + } + + private: + size_t cur_var_number; + std::unordered_map aliases; +}; + +std::ostream& operator<<(std::ostream& os, Program& program) { + ProgramPrinter printer(os); + printer.Print(program); + return os; +} + +} // namespace ir diff --git a/paddle/ir/program.h b/paddle/ir/program.h index 8b0a54d77c3..bcae617b2b9 100644 --- a/paddle/ir/program.h +++ b/paddle/ir/program.h @@ -56,4 +56,6 @@ class Program { std::unordered_map> parameters_; }; +std::ostream& operator<<(std::ostream& os, Program& program); + } // namespace ir diff --git a/paddle/ir/type.h b/paddle/ir/type.h index 37ceb39a687..fce17db82eb 100644 --- a/paddle/ir/type.h +++ b/paddle/ir/type.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "paddle/ir/cast_utils.h" #include "paddle/ir/type_base.h" @@ -76,6 +78,8 @@ class Type { return ir::dyn_cast(*this); } + void print(std::ostream &os) const; + /// /// \brief Enable hashing Type. /// diff --git a/paddle/ir/value.h b/paddle/ir/value.h index 14b2c11ad9e..423553f5aab 100644 --- a/paddle/ir/value.h +++ b/paddle/ir/value.h @@ -67,6 +67,10 @@ class ValueUseIterator { bool operator==(const ValueUseIterator &rhs) const { return current_ == rhs.current_; } + bool operator!=(const ValueUseIterator &rhs) const { + return !(*this == rhs); + } + ir::Operation *owner() const { return current_.impl()->owner(); } OperandType get() const { return current_; } -- GitLab