Unverified commit a9b1e887, authored by kangguangli, committed by GitHub

[NewIR] Printer for Program/Operation/Type and Register ops for ResNet50 (#53988)


* add conv2d

* printer draft

* fix bug

* printer draft finish

* fix windows CI

* commit printer and resnet50 related ops

* fix

* fix

* fix op definition

---------
Co-authored-by: umiswing <umiswing@foxmail.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Parent bc9b6e26
......@@ -14,7 +14,9 @@
#include "paddle/fluid/dialect/pd_dialect.h"
#include "paddle/fluid/dialect/pd_attribute.h"
#include "paddle/fluid/dialect/pd_op.h"
#include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/pd_type_storage.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
......@@ -93,6 +95,47 @@ void PaddleDialect::initialize() {
RegisterTypes<GET_PD_DIALECT_TYPE_LIST>();
RegisterAttributes<GET_PD_DIALECT_ATTRIBUTE_LIST>();
RegisterInterfaces<ParameterConvertInterface>();
RegisterOps<Conv2DOp,
FeedOp,
BatchNormOp,
BatchNormOp_,
ReluOp,
ElementwiseAddOp,
Pool2DOp,
FlattenContiguousRangeOp,
MatmulV2Op,
Reshape2Op,
SoftmaxWithCrossEntropyOp,
ReduceMeanOp,
TopKV2Op,
AccuracyOp,
ScaleOp,
FillConstantOp,
ReduceMeanGradOp,
SoftmaxWithCrossEntropyGradOp,
ElementwiseAddGradOp,
MatmulV2GradOp,
FlattenContiguousRangeGradOp,
Pool2DGradOp,
ReluGradOp,
BatchNormGradOp,
Conv2DGradOp,
SumOp,
MergedMomentumOp_,
FetchV2Op>();
}
void PaddleDialect::PrintType(ir::Type type, std::ostream& os) {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
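// Emit dims as d0xd1x...x, then the element dtype, e.g. tensor<2x3xf32>.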
os << "tensor<";
auto& dims = tensor_type.dim();
for (auto d : dims) {
os << d;
os << "x";
}
tensor_type.dtype().print(os);
os << ">";
}
} // namespace dialect
......
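For illustration, the hook above renders a DenseTensorType in a tensor<dims x dtype> form. A minimal sketch of the resulting text (assuming t is an ir::Type wrapping a DenseTensorType with dims {2, 3} and a float32 element type; the construction API is not part of this diff):
std::stringstream ss;
t.print(ss);  // dispatches through Printer::PrintType to PaddleDialect::PrintType
// ss.str() == "tensor<2x3xf32>"  (each dim is followed by 'x', then the dtype)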
......@@ -39,6 +39,8 @@ class PaddleDialect : public ir::Dialect {
static const char* name() { return "pd"; }
void PrintType(ir::Type type, std::ostream& os);
private:
void initialize();
};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/op_base.h"
namespace paddle {
namespace dialect {
#define OPNAME(op_name) "pd." #op_name
#define REGISTER_EMPTY_OP(op_name, className) \
class className : public ir::Op<className> { \
public: \
static const char *name() { return OPNAME(op_name); } \
static const char **attributes_name; \
static constexpr uint32_t attributes_num = 0; \
}; \
const char **className::attributes_name = nullptr;
REGISTER_EMPTY_OP(conv2d, Conv2DOp);
REGISTER_EMPTY_OP(feed, FeedOp);
REGISTER_EMPTY_OP(batch_norm, BatchNormOp);
REGISTER_EMPTY_OP(batch_norm_, BatchNormOp_);
REGISTER_EMPTY_OP(relu, ReluOp);
REGISTER_EMPTY_OP(elementwise_add, ElementwiseAddOp);
REGISTER_EMPTY_OP(pool2d, Pool2DOp);
REGISTER_EMPTY_OP(flatten_contiguous_range, FlattenContiguousRangeOp);
REGISTER_EMPTY_OP(matmul_v2, MatmulV2Op);
REGISTER_EMPTY_OP(reshape2, Reshape2Op);
REGISTER_EMPTY_OP(softmax_with_cross_entropy, SoftmaxWithCrossEntropyOp);
REGISTER_EMPTY_OP(reduce_mean, ReduceMeanOp);
REGISTER_EMPTY_OP(top_k_v2, TopKV2Op);
REGISTER_EMPTY_OP(scale, ScaleOp);
REGISTER_EMPTY_OP(accuracy, AccuracyOp);
REGISTER_EMPTY_OP(fill_constant, FillConstantOp);
REGISTER_EMPTY_OP(reduce_mean_grad, ReduceMeanGradOp);
REGISTER_EMPTY_OP(softmax_with_cross_entropy_grad,
SoftmaxWithCrossEntropyGradOp);
REGISTER_EMPTY_OP(elementwise_add_grad, ElementwiseAddGradOp);
REGISTER_EMPTY_OP(matmul_v2_grad, MatmulV2GradOp);
REGISTER_EMPTY_OP(flatten_contiguous_range_grad, FlattenContiguousRangeGradOp);
REGISTER_EMPTY_OP(pool2d_grad, Pool2DGradOp);
REGISTER_EMPTY_OP(relu_grad, ReluGradOp);
REGISTER_EMPTY_OP(batch_norm_grad, BatchNormGradOp);
REGISTER_EMPTY_OP(conv2d_grad, Conv2DGradOp);
REGISTER_EMPTY_OP(sum, SumOp);
REGISTER_EMPTY_OP(fetch_v2, FetchV2Op);
REGISTER_EMPTY_OP(merged_momentum_, MergedMomentumOp_);
} // namespace dialect
} // namespace paddle
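For reference, a single invocation such as REGISTER_EMPTY_OP(conv2d, Conv2DOp) above expands to roughly the following: an op class with a dialect-qualified name and no attributes.
class Conv2DOp : public ir::Op<Conv2DOp> {
 public:
  static const char *name() { return "pd.conv2d"; }  // OPNAME(conv2d)
  static const char **attributes_name;
  static constexpr uint32_t attributes_num = 0;
};
const char **Conv2DOp::attributes_name = nullptr;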
......@@ -14,6 +14,8 @@
#pragma once
#include <ostream>
#include "paddle/ir/attribute_base.h"
#include "paddle/ir/dialect_interface.h"
#include "paddle/ir/ir_context.h"
......@@ -127,6 +129,10 @@ class Dialect {
return *interface;
}
virtual void PrintType(ir::Type type, std::ostream &os) {
throw std::logic_error("dialect has no registered type printing hook");
}
private:
Dialect(const Dialect &) = delete;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <list>
#include <ostream>
#include <string>
#include <unordered_map>
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
#include "paddle/ir/operation.h"
#include "paddle/ir/program.h"
#include "paddle/ir/value.h"
namespace ir {
namespace {
constexpr char newline[] = "\n";
} // namespace
class Printer {
public:
explicit Printer(std::ostream& os) : os(os) {}
void PrintType(ir::Type type) {
if (type.isa<ir::Float16Type>()) {
os << "f16";
} else if (type.isa<ir::Float32Type>()) {
os << "f32";
} else if (type.isa<ir::Float64Type>()) {
os << "f64";
} else if (type.isa<ir::Int16Type>()) {
os << "i16";
} else if (type.isa<ir::Int32Type>()) {
os << "i32";
} else if (type.isa<ir::Int64Type>()) {
os << "i64";
} else {
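// Non-builtin types are delegated to the dialect that defines them.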
auto& dialect = type.dialect();
dialect.PrintType(type, os);
}
}
public:
std::ostream& os;
};
void Type::print(std::ostream& os) const {
if (!*this) {
os << "<!TypeNull>";
return;
}
Printer p(os);
p.PrintType(*this);
}
class ProgramPrinter : public Printer {
public:
explicit ProgramPrinter(std::ostream& os) : Printer(os), cur_var_number(0) {}
void Print(ir::Program& program) {
for (auto* op : program.ops()) {
PrintOperation(op);
os << newline;
}
}
template <typename ForwardIterator,
typename UnaryFunctor,
typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
ForwardIterator end,
UnaryFunctor print_func,
NullFunctor between_func) {
if (begin == end) return;
print_func(*begin);
begin++;
for (; begin != end; begin++) {
between_func();
print_func(*begin);
}
}
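// Usage sketch (illustrative only): given std::vector<int> xs = {1, 2, 3},
//   PrintInterleave(xs.begin(), xs.end(),
//                   [this](int x) { this->os << x; },  // print one element
//                   [this]() { this->os << ", "; });   // separator in between
// prints "1, 2, 3".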
void PrintValue(ir::Value v) {
const void* key = static_cast<const void*>(v.impl());
auto ret = aliases.find(key);
if (ret != aliases.end()) {
os << ret->second;
return;
}
std::string new_name = "%" + std::to_string(cur_var_number);
cur_var_number++;
aliases[key] = new_name;
os << new_name;
}
/// @brief print operation
/// @param op the operation to print
/// @example (%0) = "pd.conv2d" (%1,%2) { ATTRIBUTE } : tensor<...>,tensor<...> -> tensor<...>
void PrintOperation(ir::Operation* op) {
PrintOpResult(op); // TODO(lyk): add API to get opresults directly
os << " = ";
os << "\"" << op->op_name() << "\"";
PrintOpOperands(op); // TODO(lyk): add API to get operands directly
PrintAttribute(op);
os << " : ";
// PrintOpSignature
PrintOperandsType(op);
os << " -> ";
PrintOpReturnType(op); // TODO(lyk): add API to get opresults directly
}
void PrintOpResult(ir::Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<ir::OpResult> op_results;
op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx));
}
PrintInterleave(
op_results.begin(),
op_results.end(),
[this](ir::Value v) { this->PrintValue(v); },
[this]() { this->os << ","; });
os << ") ";
}
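// Placeholder: attribute printing is not implemented yet.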
void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE } "; }
void PrintOpOperands(ir::Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<ir::Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).impl()->source());
}
PrintInterleave(
op_operands.begin(),
op_operands.end(),
[this](ir::Value v) { this->PrintValue(v); },
[this]() { this->os << ","; });
os << ") ";
}
void PrintOperandsType(ir::Operation* op) {
auto num_op_operands = op->num_operands();
std::vector<ir::Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operand_types.push_back(
op->GetOperandByIndex(idx).impl()->source().type());
}
PrintInterleave(
op_operand_types.begin(),
op_operand_types.end(),
[this](ir::Type t) { this->PrintType(t); },
[this]() { this->os << ","; });
}
void PrintOpReturnType(ir::Operation* op) {
auto num_op_result = op->num_results();
std::vector<ir::Type> op_result_types;
op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
op_result_types.push_back(op->GetResultByIndex(idx).type());
}
PrintInterleave(
op_result_types.begin(),
op_result_types.end(),
[this](ir::Type t) { this->PrintType(t); },
[this]() { this->os << ","; });
}
private:
size_t cur_var_number;
std::unordered_map<const void*, std::string> aliases;
};
std::ostream& operator<<(std::ostream& os, Program& program) {
ProgramPrinter printer(os);
printer.Print(program);
return os;
}
} // namespace ir
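With the operator<< overload above, dumping a whole program is a one-liner. A sketch (assuming program is an already-built ir::Program; each operation is printed on its own line in the form documented at PrintOperation):
std::cout << program;
// roughly, one line per op, e.g.:
//  (%0)  = "pd.feed" ()  { ATTRIBUTE }  :  -> tensor<...>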
......@@ -56,4 +56,6 @@ class Program {
std::unordered_map<std::string, std::unique_ptr<Parameter>> parameters_;
};
std::ostream& operator<<(std::ostream& os, Program& program);
} // namespace ir
......@@ -14,6 +14,8 @@
#pragma once
#include <ostream>
#include "paddle/ir/cast_utils.h"
#include "paddle/ir/type_base.h"
......@@ -76,6 +78,8 @@ class Type {
return ir::dyn_cast<U>(*this);
}
void print(std::ostream &os) const;
///
/// \brief Enable hashing Type.
///
......
......@@ -67,6 +67,10 @@ class ValueUseIterator {
bool operator==(const ValueUseIterator<OperandType> &rhs) const {
return current_ == rhs.current_;
}
bool operator!=(const ValueUseIterator<OperandType> &rhs) const {
return !(*this == rhs);
}
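// With operator!= in place, the iterator works with the standard loop idiom.
// Sketch only: assumes operator++ and a source of begin/end iterators, which
// are not part of this diff.
//   for (auto it = first_use; it != last_use; ++it) {
//     ir::Operation *consumer = it.owner();  // op that uses the value
//   }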
ir::Operation *owner() const { return current_.impl()->owner(); }
OperandType get() const { return current_; }
......