未验证 提交 a2d61455 编写于 作者: X xingmingyyj 提交者: GitHub

[NewIR] Add parser to deserialize (#55695)

* add parser

* add parser

* add parser

* add parser

* add parser

* add parser

* add parser

* Update test/cpp/ir/core/ir_parser_test.cc
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>

* add parser

* add parser

* add parser

* Update test/cpp/ir/core/program_translator_test.cc
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>

* Update test/cpp/ir/core/program_translator_test.cc
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>

* Update dialect.h

* add parser

* add parser

* Update CMakeLists.txt

* add parser

---------
Co-authored-by: Nkangguangli <kangguangli@hotmail.com>
上级 6e17e661
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
const phi::IntArray& IntArrayAttribute::data() const { const phi::IntArray &IntArrayAttribute::data() const {
return storage()->GetAsKey(); return storage()->GetAsKey();
} }
...@@ -48,6 +48,109 @@ phi::Scalar ScalarAttribute::data() { ...@@ -48,6 +48,109 @@ phi::Scalar ScalarAttribute::data() {
} }
} }
IntArrayAttribute IntArrayAttribute::Parse(ir::IrParser &parser) { // NOLINT
Token buket_token = parser.ConsumeToken();
std::vector<int32_t> vec{};
while (parser.PeekToken().val_ != "]") {
Token val_token = parser.ConsumeToken();
vec.push_back(atoll(val_token.val_.c_str()));
if (parser.PeekToken().val_ == "]") break;
parser.ConsumeToken();
}
parser.ConsumeToken();
return IntArrayAttribute::get(parser.ctx, vec);
}
// Parse a DataTypeAttribute
// DataTypeAttribute := bool|uint8|int8|uint16|int16|uint32
// |int32|uint64|int64|float32|complex64
// |complex128|Undefined|psting|flaot16
// |bfloat16|num_data_types|all_dtype
DataTypeAttribute DataTypeAttribute::Parse(ir::IrParser &parser) { // NOLINT
std::unordered_map<std::string, phi::DataType> StringToDataType{
{"bool", phi::DataType::BOOL},
{"uint8", phi::DataType::UINT8},
{"int8", phi::DataType::INT8},
{"uint16", phi::DataType::UINT16},
{"int16", phi::DataType::INT16},
{"uint32", phi::DataType::UINT32},
{"int32", phi::DataType::INT32},
{"uint64", phi::DataType::UINT64},
{"int64", phi::DataType::INT64},
{"float32", phi::DataType::FLOAT32},
{"complex64", phi::DataType::COMPLEX64},
{"complex128", phi::DataType::COMPLEX128},
{"Undefined", phi::DataType::UNDEFINED},
{"psting", phi::DataType::PSTRING},
{"float16", phi::DataType::FLOAT16},
{"bfloat16", phi::DataType::BFLOAT16},
{"float64", phi::DataType::FLOAT64}};
std::string datatype_token_val = parser.ConsumeToken().val_;
IR_ENFORCE(StringToDataType.count(datatype_token_val) > 0,
datatype_token_val + " is not defined in DataType." +
parser.GetErrorLocationInfo());
return DataTypeAttribute::get(parser.ctx,
StringToDataType[datatype_token_val]);
}
// Parse a PlaceAttribute
// PlaceAttribute := Place(cpu)|Place(gpu:0)|Place(gpu_pinned)
// |Place(xpu:0)|Place(ipu:0)|Place(:0)|undefined
PlaceAttribute PlaceAttribute::Parse(ir::IrParser &parser) { // NOLINT
std::unordered_map<std::string, phi::Place> StringToPlace{
{"cpu", phi::CPUPlace{}},
{"gpu", phi::GPUPlace{}},
{"gpu_pinned", phi::GPUPinnedPlace{}},
{"xpu", phi::XPUPlace{}},
{"ipu", phi::IPUPlace{}},
{":", phi::CustomPlace{}},
{"undefined", phi::Place{}}};
parser.ConsumeAToken("Place");
parser.ConsumeAToken("(");
std::string place_token_val = parser.ConsumeToken().val_;
IR_ENFORCE(StringToPlace.count(place_token_val) > 0,
place_token_val + " is not defined in Place." +
parser.GetErrorLocationInfo());
if (parser.PeekToken().val_ == ":") {
parser.ConsumeAToken(":");
parser.ConsumeToken();
} else if (place_token_val == ":") {
parser.ConsumeToken();
}
parser.ConsumeAToken(")");
return PlaceAttribute::get(parser.ctx, StringToPlace[place_token_val]);
}
// Parse a DataLayoutAttribute
// DataLayoutAttribute := NHWC|NCHW|Undefined(0)|ONEDNN
// |SPARSE_COO|SPARSE_CSR|NDHWC
// |NCDHW|PSTRING_UNION|STRIDED
DataLayoutAttribute DataLayoutAttribute::Parse(
ir::IrParser &parser) { // NOLINT
std::unordered_map<std::string, phi::DataLayout> StringToDataLayout{
{"NHWC", phi::DataLayout::kNHWC},
{"NCHW", phi::DataLayout::kNCHW},
{"Undefined", phi::DataLayout::kAnyLayout},
{"ONEDNN", phi::DataLayout::ONEDNN},
{"SPARSE_COO", phi::DataLayout::SPARSE_COO},
{"SPARSE_CSR", phi::DataLayout::SPARSE_CSR},
{"NDHWC", phi::DataLayout::kNDHWC},
{"NCDHW", phi::DataLayout::kNCDHW},
{"PSTRING_UNION", phi::DataLayout::PSTRING_UNION},
{"STRIDED", phi::DataLayout::STRIDED}};
std::string datalayout_token_val = parser.ConsumeToken().val_;
IR_ENFORCE(StringToDataLayout.count(datalayout_token_val) > 0,
datalayout_token_val + " is not defined in DataLayout." +
parser.GetErrorLocationInfo());
if (datalayout_token_val == "Undefined") {
parser.ConsumeAToken("(");
parser.ConsumeAToken("AnyLayout");
parser.ConsumeAToken(")");
}
return DataLayoutAttribute::get(parser.ctx,
StringToDataLayout[datalayout_token_val]);
}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/ir_parser.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
...@@ -34,6 +35,8 @@ class IntArrayAttribute : public ir::Attribute { ...@@ -34,6 +35,8 @@ class IntArrayAttribute : public ir::Attribute {
return storage() < right.storage(); return storage() < right.storage();
} }
static IntArrayAttribute Parse(ir::IrParser &parser); // NOLINT
const phi::IntArray &data() const; const phi::IntArray &data() const;
}; };
...@@ -68,6 +71,8 @@ class DataTypeAttribute : public ir::Attribute { ...@@ -68,6 +71,8 @@ class DataTypeAttribute : public ir::Attribute {
return storage() < right.storage(); return storage() < right.storage();
} }
static DataTypeAttribute Parse(ir::IrParser &parser); // NOLINT
phi::DataType data() const; phi::DataType data() const;
}; };
...@@ -81,6 +86,8 @@ class PlaceAttribute : public ir::Attribute { ...@@ -81,6 +86,8 @@ class PlaceAttribute : public ir::Attribute {
return storage() < right.storage(); return storage() < right.storage();
} }
static PlaceAttribute Parse(ir::IrParser &parser); // NOLINT
phi::Place data() const; phi::Place data() const;
}; };
...@@ -95,6 +102,7 @@ class DataLayoutAttribute : public ir::Attribute { ...@@ -95,6 +102,7 @@ class DataLayoutAttribute : public ir::Attribute {
return storage() < right.storage(); return storage() < right.storage();
} }
static DataLayoutAttribute Parse(ir::IrParser &parser); // NOLINT
phi::DataLayout data() const; phi::DataLayout data() const;
}; };
......
...@@ -83,9 +83,12 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const { ...@@ -83,9 +83,12 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const {
} }
void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
os << "(" << attr.dialect().name();
os << '.';
if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) { if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) {
phi::IntArray data = int_array_attr.data(); phi::IntArray data = int_array_attr.data();
os << "IntArray["; os << "IntArray)"
<< "[";
const auto &inner_data = data.GetData(); const auto &inner_data = data.GetData();
ir::PrintInterleave( ir::PrintInterleave(
inner_data.begin(), inner_data.begin(),
...@@ -94,16 +97,64 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { ...@@ -94,16 +97,64 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
[&os]() { os << ","; }); [&os]() { os << ","; });
os << "]"; os << "]";
} else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) { } else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) {
os << data_type_attr.data(); os << "DataType)" << data_type_attr.data();
} else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) { } else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) {
os << place_type_attr.data(); os << "Place)" << place_type_attr.data();
} else if (auto data_layout_attr = attr.dyn_cast<DataLayoutAttribute>()) { } else if (auto data_layout_attr = attr.dyn_cast<DataLayoutAttribute>()) {
os << data_layout_attr.data(); os << "DataLayout)" << data_layout_attr.data();
} else { } else {
os << "<#AttrNotImplemented>"; os << "<#AttrNotImplemented>";
} }
} }
ir::Type PaddleDialect::ParseType(ir::IrParser &parser) { // NOLINT
parser.ConsumeAToken("pd.tensor");
parser.ConsumeAToken("<");
std::vector<int> dim{};
Token dim_token = parser.PeekToken();
while (dim_token.token_type_ == DIGIT) {
dim_token = parser.ConsumeToken();
dim.push_back(atoi(dim_token.val_.c_str()));
std::string peek_token_val = parser.PeekToken().val_;
if (peek_token_val[0] != 'x') {
break;
}
parser.ConsumeToken();
parser.lexer->Unget(peek_token_val.size() - 1);
if (parser.PeekToken().token_type_ != DIGIT) {
break;
}
}
phi::DDim ddim = phi::make_ddim(dim);
ir::Type dtype = parser.ParseType();
std::vector<std::vector<size_t>> lod;
std::vector<size_t> lodv;
lodv.push_back(0);
lod.push_back(lodv);
parser.ConsumeAToken(">");
return DenseTensorType::get(
parser.ctx, dtype, ddim, phi::DataLayout::UNDEFINED, lod, 0);
}
ir::Attribute PaddleDialect::ParseAttribute(ir::IrParser &parser) { // NOLINT
std::string type_name = parser.ConsumeToken().val_;
std::string attribute_name =
type_name.substr(type_name.find('.') + 1, std::string::npos);
parser.ConsumeAToken(")");
if (attribute_name == "IntArray") {
return IntArrayAttribute::Parse(parser);
} else if (attribute_name == "DataType") {
return DataTypeAttribute::Parse(parser);
} else if (attribute_name == "Place") {
return PlaceAttribute::Parse(parser);
} else if (attribute_name == "DataLayout") {
return DataLayoutAttribute::Parse(parser);
} else {
IR_THROW("No function to parse " + attribute_name + " exists!" +
parser.GetErrorLocationInfo());
}
}
void PaddleDialect::PrintOperation(ir::Operation *op, void PaddleDialect::PrintOperation(ir::Operation *op,
ir::IrPrinter &printer) const { ir::IrPrinter &printer) const {
if (auto if_op = op->dyn_cast<IfOp>()) { if (auto if_op = op->dyn_cast<IfOp>()) {
......
...@@ -25,6 +25,9 @@ class PaddleDialect : public ir::Dialect { ...@@ -25,6 +25,9 @@ class PaddleDialect : public ir::Dialect {
static const char* name() { return "pd"; } static const char* name() { return "pd"; }
ir::Type ParseType(ir::IrParser& parser) override; // NOLINT
ir::Attribute ParseAttribute(ir::IrParser& parser) override; // NOLINT
void PrintType(ir::Type type, std::ostream& os) const override; void PrintType(ir::Type type, std::ostream& os) const override;
void PrintAttribute(ir::Attribute type, std::ostream& os) const override; void PrintAttribute(ir::Attribute type, std::ostream& os) const override;
......
...@@ -3,4 +3,8 @@ set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir") ...@@ -3,4 +3,8 @@ set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir")
file(GLOB IR_SRCS "*.cc") file(GLOB IR_SRCS "*.cc")
file(GLOB IR_PARSER_SRCS "parser/*.cc")
list(APPEND IR_SRCS ${IR_PARSER_SRCS})
ir_library(ir_core SRCS ${IR_SRCS} DEPS ddim) ir_library(ir_core SRCS ${IR_SRCS} DEPS ddim)
...@@ -68,6 +68,8 @@ class IR_API Attribute { ...@@ -68,6 +68,8 @@ class IR_API Attribute {
/// @param os /// @param os
void Print(std::ostream &os) const; void Print(std::ostream &os) const;
static Attribute Parse(std::istream &is, IrContext *ctx);
/// ///
/// \brief Methods for type judgment and cast. /// \brief Methods for type judgment and cast.
/// ///
......
...@@ -29,7 +29,7 @@ namespace ir { ...@@ -29,7 +29,7 @@ namespace ir {
class Operation; class Operation;
class IrPrinter; class IrPrinter;
class IrParser;
class DialectInterface; class DialectInterface;
/// ///
/// \brief Dialect can basically be understood as a namespace. In Dialect, we /// \brief Dialect can basically be understood as a namespace. In Dialect, we
...@@ -145,9 +145,21 @@ class IR_API Dialect { ...@@ -145,9 +145,21 @@ class IR_API Dialect {
IR_THROW("dialect has no registered attribute printing hook"); IR_THROW("dialect has no registered attribute printing hook");
} }
virtual Type ParseType(IrParser &parser) { // NOLINT
IR_THROW("dialect has no registered type parsing hook");
}
virtual Attribute ParseAttribute(IrParser &parser) { // NOLINT
IR_THROW("dialect has no registered attribute parsing hook");
}
virtual void PrintOperation(Operation *op, virtual void PrintOperation(Operation *op,
IrPrinter &printer) const; // NOLINT IrPrinter &printer) const; // NOLINT
virtual Operation ParseOperation(IrParser &parser) { // NOLINT
IR_THROW("dialect has no registered operation parsing hook");
}
private: private:
Dialect(const Dialect &) = delete; Dialect(const Dialect &) = delete;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parser/lexer.h"
#include "paddle/ir/core/program.h"
using OpResultMap = std::map<std::string, ir::OpResult>;
using AttributeMap = std::unordered_map<std::string, ir::Attribute>;
using OpAttributeInfoMap = std::map<std::string, std::string>;
namespace ir {
class IrParser {
public:
std::unique_ptr<Lexer> lexer;
IrContext* ctx;
OpResultMap opresultmap;
std::unique_ptr<Builder> builder;
public:
IrParser(IrContext* ctx, std::istream& is);
~IrParser() = default;
Token ConsumeToken();
Token PeekToken();
std::unique_ptr<Program> ParseProgram();
void ParseRegion(Region& region); // NOLINT
void ParseBlock(Block& block); // NOLINT
Operation* ParseOperation();
OpInfo ParseOpInfo();
std::vector<std::string> ParseOpResultList();
std::vector<OpResult> ParseOpRandList();
AttributeMap ParseAttributeMap();
std::vector<Type> ParseTypeList();
OpResult GetNullValue();
Type ParseType();
Attribute ParseAttribute();
std::string GetErrorLocationInfo();
void ConsumeAToken(std::string expect_token_val);
};
} // namespace ir
...@@ -87,22 +87,27 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { ...@@ -87,22 +87,27 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
} }
if (auto s = attr.dyn_cast<StrAttribute>()) { if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.AsString(); os << "(String)" << s.AsString();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) { } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data(); if (b.data()) {
os << "true";
} else {
os << "false";
}
} else if (auto f = attr.dyn_cast<FloatAttribute>()) { } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
os << f.data(); os << "(Float)" << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) { } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data(); os << "(Double)" << d.data();
} else if (auto i = attr.dyn_cast<Int32Attribute>()) { } else if (auto i = attr.dyn_cast<Int32Attribute>()) {
os << i.data(); os << "(Int32)" << i.data();
} else if (auto i = attr.dyn_cast<Int64Attribute>()) { } else if (auto i = attr.dyn_cast<Int64Attribute>()) {
os << i.data(); os << "(Int64)" << i.data();
} else if (auto p = attr.dyn_cast<PointerAttribute>()) { } else if (auto p = attr.dyn_cast<PointerAttribute>()) {
os << p.data(); os << "(Pointer)" << p.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) { } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.AsVector(); const auto& vec = arr.AsVector();
os << "array["; os << "(Array)"
<< "[";
PrintInterleave( PrintInterleave(
vec.begin(), vec.begin(),
vec.end(), vec.end(),
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/core/ir_parser.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_type.h"
namespace ir {
IrParser::IrParser(IrContext* ctx, std::istream& is) {
lexer.reset(new Lexer{is});
this->ctx = ctx;
builder.reset(new Builder{ctx});
}
Token IrParser::ConsumeToken() {
auto token = lexer->ConsumeToken();
return token;
}
std::string IrParser::GetErrorLocationInfo() {
return "The error occurred in line " + std::to_string(lexer->GetLine()) +
", column " + std::to_string(lexer->GetColumn());
}
Token IrParser::PeekToken() {
auto token = lexer->ConsumeToken();
if (token.token_type_ != EOF_) {
lexer->Unget(token.val_.size());
}
return token;
}
void IrParser::ConsumeAToken(std::string expect_token_val) {
std::string token_val = ConsumeToken().val_;
IR_ENFORCE(token_val == expect_token_val,
"The token value of expectation is " + expect_token_val + " ,not" +
token_val + "." + GetErrorLocationInfo());
}
// Type := BuiltinType | OtherDialectsDefineType
// BuiltinType := <<NULL TYPE>> | bf16 | f16 | f32 | f64
// := | b | i8 | u8 | i16 | i32 | i64 | index | c64
// := | c128 | VectorType
// VectorType := '[' Type(,Type)* ']'
Type IrParser::ParseType() {
Token type_token = PeekToken();
std::string type_val = type_token.val_;
if (type_val == "<<NULL TYPE>>") {
ConsumeToken();
return Type(nullptr);
} else if (type_val == "bf16") {
ConsumeToken();
return builder->bfloat16_type();
} else if (type_val == "f16") {
ConsumeToken();
return builder->bfloat16_type();
} else if (type_val == "f32") {
ConsumeToken();
return builder->float32_type();
} else if (type_val == "f64") {
ConsumeToken();
return builder->float64_type();
} else if (type_val == "b") {
ConsumeToken();
return builder->bool_type();
} else if (type_val == "i8") {
ConsumeToken();
return builder->int8_type();
} else if (type_val == "u8") {
ConsumeToken();
return builder->uint8_type();
} else if (type_val == "i16") {
ConsumeToken();
return builder->int16_type();
} else if (type_val == "i32") {
ConsumeToken();
return Int32Type::get(ctx);
} else if (type_val == "i64") {
ConsumeToken();
return Int64Type::get(ctx);
} else if (type_val == "index") {
ConsumeToken();
return IndexType::get(ctx);
} else if (type_val == "c64") {
ConsumeToken();
return builder->complex64_type();
} else if (type_val == "c128") {
ConsumeToken();
return builder->complex128_type();
} else if (type_val == "vec") {
ConsumeAToken("vec");
ConsumeAToken("[");
std::vector<Type> vec_type;
Token vec_type_token = PeekToken();
while (vec_type_token.val_ != "]") {
Type cur_type = ParseType();
vec_type.push_back(cur_type);
vec_type_token = ConsumeToken();
}
return VectorType::get(ctx, vec_type);
} else {
IR_ENFORCE(type_val.find('.') != std::string::npos,
"No function parsing " + type_val + " exists!" +
GetErrorLocationInfo());
auto dialect_name = type_val.substr(0, type_val.find('.'));
auto dialect = ctx->GetRegisteredDialect(dialect_name);
return dialect->ParseType(*this);
}
}
// Attribute := BuiltinAttribute | OtherDialectsDefineAttribute
// BuiltinAttribute := Bool | String | Float | Double | Int32 |
// := | Int64 | Pointer | ArrayAttribute
// ArrayAttribute := '[' Atribute(,Attribute)* ']'
Attribute IrParser::ParseAttribute() {
auto parenthesis_token = ConsumeToken();
if (parenthesis_token.val_ == "true" || parenthesis_token.val_ == "false") {
return builder->bool_attr(parenthesis_token.val_ == "true");
}
std::string attribute_type = PeekToken().val_;
if (attribute_type == "String") {
ConsumeAToken("String");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
return builder->str_attr(val);
} else if (attribute_type == "Float") {
ConsumeAToken("Float");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
return builder->float_attr(atof(val.c_str()));
} else if (attribute_type == "Double") {
ConsumeAToken("Double");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
return builder->double_attr(atof(val.c_str()));
} else if (attribute_type == "Int32") {
ConsumeAToken("Int32");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
return builder->int32_attr(atoi(val.c_str()));
} else if (attribute_type == "Int64") {
ConsumeAToken("Int64");
ConsumeAToken(")");
std::string val = ConsumeToken().val_;
return builder->int64_attr(atoll(val.c_str()));
} else if (attribute_type == "Pointer") {
IR_THROW("This attribute is not currently supported by parser");
} else if (attribute_type == "Array") {
ConsumeAToken("Array");
ConsumeAToken(")");
ConsumeAToken("[");
std::vector<Attribute> array_attribute;
while (PeekToken().val_ != "]") {
array_attribute.push_back(ParseAttribute());
if (PeekToken().val_ == "]") break;
ConsumeAToken(",");
}
ConsumeAToken("]");
return builder->array_attr(array_attribute);
} else {
IR_ENFORCE(attribute_type.find('.') != std::string::npos,
"No function parsing " + attribute_type + " exists!" +
GetErrorLocationInfo());
auto dialect_name = attribute_type.substr(0, attribute_type.find('.'));
auto dialect = ctx->GetRegisteredDialect(dialect_name);
return dialect->ParseAttribute(*this);
}
}
// Program := [ParameterList]ModuleOp
// ModuleOp := Region
std::unique_ptr<Program> IrParser::ParseProgram() {
std::unique_ptr<Program> program(new Program{ctx});
auto top_level_op = program->module_op();
auto& region = top_level_op->region(0);
ParseRegion(region);
return program;
}
// Region := Block
void IrParser::ParseRegion(Region& region) { // NOLINT
ParseBlock(*region.front());
IR_ENFORCE(PeekToken().val_ != "{",
"Only one block in a region is supported");
}
// Block := "{" {Operation} "}"
void IrParser::ParseBlock(Block& block) { // NOLINT
ConsumeAToken("{");
while (PeekToken().val_ != "}") {
auto op = ParseOperation();
block.push_back(op);
}
ConsumeAToken("}");
}
// Operation := OpResultList ":=" Opname "(" OprandList ? ")" AttributeMap ":"
// FunctionType
// FunctionType := "(" TypeList ")" "->" TypeList
Operation* IrParser::ParseOperation() {
std::vector<std::string> opresultindex = ParseOpResultList();
ConsumeAToken("=");
OpInfo opinfo = ParseOpInfo();
std::vector<OpResult> inputs = ParseOpRandList();
ir::AttributeMap attributeMap = ParseAttributeMap();
ConsumeAToken(":");
ConsumeAToken("(");
ParseTypeList();
ConsumeAToken(")");
ConsumeAToken("->");
std::vector<Type> type_vector = ParseTypeList();
Operation* op =
Operation::Create(inputs, attributeMap, type_vector, opinfo, 0);
for (uint32_t i = 0; i < op->num_results(); i++) {
std::string key_t = opresultindex[i];
opresultmap[key_t] = op->result(i);
}
return op;
}
// OpResultList := ValueList
// ValueList := ValueId(,ValueId)*
std::vector<std::string> IrParser::ParseOpResultList() {
std::vector<std::string> opresultindex{};
ConsumeAToken("(");
Token index_token = ConsumeToken();
while (index_token.val_ != ")") {
if (index_token.token_type_ == NULL_) {
opresultindex.push_back("null");
} else {
std::string str = index_token.val_;
opresultindex.push_back(str);
}
if (ConsumeToken().val_ == ")") break;
index_token = ConsumeToken();
}
return opresultindex;
}
// OpName := "\"" StringIdentifer "." StringIdentifer "\""
OpInfo IrParser::ParseOpInfo() {
Token opname_token = ConsumeToken();
std::string opname =
opname_token.val_.substr(1, opname_token.val_.size() - 2);
return ctx->GetRegisteredOpInfo(opname);
}
// OprandList := ValueList
// ValueList := ValueId(,ValueId)*
std::vector<OpResult> IrParser::ParseOpRandList() {
ConsumeAToken("(");
std::vector<OpResult> inputs{};
Token ind_token = ConsumeToken();
while (ind_token.val_ != ")") {
std::string t = "";
if (ind_token.token_type_ == NULL_) {
inputs.push_back(GetNullValue());
} else {
t = ind_token.val_;
inputs.push_back(opresultmap[t]);
}
Token token = ConsumeToken();
if (token.val_ == ")") {
break;
}
ind_token = ConsumeToken();
}
return inputs;
}
// AttributeMap := "{" AttributeEntry,(,AttributeEntry)* "}"
// AttributeEntry := StringIdentifer:Attribute
AttributeMap IrParser::ParseAttributeMap() {
AttributeMap attribute_map{};
ConsumeAToken("{");
Token key_token = ConsumeToken();
while (key_token.val_ != "}") {
ConsumeAToken(":");
attribute_map[key_token.val_] = ParseAttribute();
std::string token_val = ConsumeToken().val_;
if (token_val == "}") {
break;
} else if (token_val == ",") {
key_token = ConsumeToken();
} else {
IR_ENFORCE((token_val == "}") || (token_val == ","),
"The token value of expectation is } or , , not " + token_val +
"." + GetErrorLocationInfo());
}
}
return attribute_map;
}
// TypeList := Type(,Type)*
std::vector<Type> IrParser::ParseTypeList() {
std::vector<Type> type_vector{};
while (PeekToken().val_ != "(" && PeekToken().val_ != "}" &&
PeekToken().val_ != ")") {
type_vector.push_back(ParseType());
if (PeekToken().val_ == "}" || PeekToken().val_ == "(" ||
PeekToken().val_ == ")" || PeekToken().token_type_ == EOF_)
break;
ConsumeAToken(",");
}
return type_vector;
}
OpResult IrParser::GetNullValue() {
Value* v = new Value{nullptr};
OpResult* opresult = static_cast<OpResult*>(v);
return *opresult;
}
Attribute Attribute::Parse(std::istream& is, IrContext* ctx) {
IrParser parser(ctx, is);
return parser.ParseAttribute();
}
Type Type::Parse(std::istream& is, IrContext* ctx) {
IrParser parser(ctx, is);
return parser.ParseType();
}
std::unique_ptr<Program> Program::Parse(std::istream& is, IrContext* ctx) {
IrParser parser(ctx, is);
return parser.ParseProgram();
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/core/parser/lexer.h"
Token Lexer::ConsumeToken() {
SkipWhitespace();
if (auto token = LexIdentifer()) {
return *token;
} else if (auto token = LexNumberOrArraow()) {
return *token;
} else if (auto token = LexEndTagOrNullVal()) {
return *token;
} else if (auto token = LexValueId()) {
return *token;
} else if (auto token = LexOpName()) {
return *token;
} else if (auto token = LexEOF()) {
return *token;
} else {
return Token{"Error", NULL_};
}
}
char Lexer::GetChar() {
char c = is.get();
if (c == '\n') {
line++;
column = 1;
} else {
column++;
}
return c;
}
size_t Lexer::GetColumn() { return column; }
size_t Lexer::GetLine() { return line; }
void Lexer::SkipWhitespace() {
while (IsSpace(is.peek())) {
GetChar();
}
}
std::unique_ptr<Token> Lexer::LexIdentifer() {
if ((!isalpha(is.peek()) && is.peek() != '_') || IsEndTag(is.peek())) {
return nullptr;
}
std::string token_identifier = "";
while (isalnum(is.peek()) || is.peek() == '_' || is.peek() == '.') {
token_identifier += GetChar();
}
std::unique_ptr<Token> token(new Token{token_identifier, IDENTIFER});
return token;
}
std::unique_ptr<Token> Lexer::LexNumberOrArraow() {
if (!isdigit(is.peek()) && is.peek() != '-') {
return nullptr;
}
std::string token_digit = "";
token_digit += GetChar();
if (token_digit[0] == '-' && is.peek() == '>') {
GetChar();
std::unique_ptr<Token> arrow_token(new Token{"->", ARRAOW});
return arrow_token;
}
while (isdigit(is.peek())) {
token_digit += GetChar();
}
if (is.peek() == '.') {
token_digit += GetChar();
while (isdigit(is.peek())) {
token_digit += GetChar();
}
}
if (is.peek() == 'e') {
token_digit += GetChar();
if (is.peek() == '+' || is.peek() == '-') {
token_digit += GetChar();
}
while (isdigit(is.peek())) {
token_digit += GetChar();
}
std::unique_ptr<Token> sdigit_token(new Token{token_digit, SDIGIT});
return sdigit_token;
}
std::unique_ptr<Token> digit_token(new Token{token_digit, DIGIT});
return digit_token;
}
std::unique_ptr<Token> Lexer::LexEndTagOrNullVal() {
if (!IsEndTag(is.peek())) {
return nullptr;
}
std::string token_end = "";
token_end += GetChar();
if ((token_end[0] == '<' && (is.peek() != '<' && is.peek() != '#')) ||
token_end[0] != '<') {
std::unique_ptr<Token> endtag_token(new Token{token_end, ENDTAG});
return endtag_token;
}
if (is.peek() == '<') {
std::string token_null_val = "";
GetChar();
while (is.peek() != '>') {
token_null_val += GetChar();
}
GetChar();
GetChar();
std::unique_ptr<Token> null_token(
new Token{"<<" + token_null_val + ">>", NULL_});
return null_token;
} else {
std::string token_attrnull = "";
while (is.peek() != '>') {
token_attrnull += GetChar();
}
GetChar();
std::unique_ptr<Token> null_token(
new Token{"<" + token_attrnull + ">", NULL_});
return null_token;
}
}
std::unique_ptr<Token> Lexer::LexValueId() {
if (is.peek() != '%') {
return nullptr;
}
std::string token_valueid = "";
token_valueid += GetChar();
while (isdigit(is.peek())) {
token_valueid += GetChar();
}
std::unique_ptr<Token> valueid_token(new Token{token_valueid, VALUEID});
return valueid_token;
}
std::unique_ptr<Token> Lexer::LexEOF() {
if (is.peek() == EOF) {
std::unique_ptr<Token> eof_token(new Token{"LEX_DOWN", EOF_});
return eof_token;
} else {
return nullptr;
}
}
std::unique_ptr<Token> Lexer::LexOpName() {
if (is.peek() != '"') {
return nullptr;
}
GetChar();
std::string token_opname = "";
while (is.peek() != '"') {
token_opname += GetChar();
}
GetChar();
std::unique_ptr<Token> opname_token(
new Token{"\"" + token_opname + "\"", OPNAME});
return opname_token;
}
bool Lexer::IsSpace(char c) {
return c == ' ' || c == '\n' || c == '\t' || c == '\f';
}
bool Lexer::IsEndTag(char c) {
return c == '{' || c == '}' || c == '(' || c == ')' || c == ':' || c == '>' ||
c == ',' || c == ']' || c == '[' || c == '+' || c == '=' || c == '<';
}
void Lexer::Unget(const int len) {
if (is.eof()) {
is.clear();
}
column -= len;
is.seekg(-len, std::ios::cur);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <istream>
#include <memory>
#include "paddle/ir/core/parser/token.h"
class Lexer {
private:
std::istream& is;
size_t line = 1;
size_t column = 1;
public:
explicit Lexer(std::istream& is) : is(is) {}
~Lexer() = default;
Token ConsumeToken();
std::unique_ptr<Token> LexIdentifer();
std::unique_ptr<Token> LexNumberOrArraow();
std::unique_ptr<Token> LexEndTagOrNullVal();
std::unique_ptr<Token> LexValueId();
std::unique_ptr<Token> LexEOF();
std::unique_ptr<Token> LexOpName();
char GetChar();
void SkipWhitespace();
bool IsEndTag(char);
bool IsSpace(char);
size_t GetLine();
size_t GetColumn();
void Unget(const int len);
};
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
enum Token_type {
EOF_ = -1,
IDENTIFER = 0,
DIGIT = 1,
SDIGIT = 2,
ENDTAG = 3,
VALUEID = 4,
OPNAME = 5,
ARRAOW = 6,
NULL_ = 7,
};
struct Token {
public:
std::string val_;
Token_type token_type_;
Token() = default;
Token(const std::string& val, Token_type token_type) {
val_ = val;
token_type_ = token_type;
}
};
...@@ -52,6 +52,8 @@ class IR_API Program { ...@@ -52,6 +52,8 @@ class IR_API Program {
void Print(std::ostream& os) const; void Print(std::ostream& os) const;
static std::unique_ptr<Program> Parse(std::istream& is, IrContext* ctx);
Block* block() { return module_.block(); } Block* block() { return module_.block(); }
const Block* block() const { return module_op().block(); } const Block* block() const { return module_op().block(); }
......
...@@ -84,6 +84,8 @@ class IR_API Type { ...@@ -84,6 +84,8 @@ class IR_API Type {
void Print(std::ostream &os) const; void Print(std::ostream &os) const;
static Type Parse(std::istream &is, IrContext *ctx);
/// ///
/// \brief Enable hashing Type. /// \brief Enable hashing Type.
/// ///
......
...@@ -65,6 +65,9 @@ file( ...@@ -65,6 +65,9 @@ file(
${CMAKE_CURRENT_BINARY_DIR}/resnet50_startup.prog ${CMAKE_CURRENT_BINARY_DIR}/resnet50_startup.prog
EXPECTED_MD5 6affc5f40f0f0bb84d956919b95eaf50) EXPECTED_MD5 6affc5f40f0f0bb84d956919b95eaf50)
copy_if_different(${CMAKE_CURRENT_SOURCE_DIR}/TestParserText.txt
${CMAKE_CURRENT_BINARY_DIR}/TestParserText.txt)
cc_test_old( cc_test_old(
program_translator_test program_translator_test
SRCS SRCS
...@@ -75,6 +78,24 @@ cc_test_old( ...@@ -75,6 +78,24 @@ cc_test_old(
pd_dialect pd_dialect
ir) ir)
cc_test_old(
add_dialect_parser_test
SRCS
add_dialect_parser_test.cc
DEPS
gtest
pd_dialect
ir)
cc_test_old(
ir_parser_test
SRCS
ir_parser_test.cc
DEPS
gtest
pd_dialect
ir)
cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest ir) cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest ir)
cc_test_old( cc_test_old(
ir_op_yaml_info_parser_test ir_op_yaml_info_parser_test
......
//CHECK attribute
(String)sdfgs.sdsd
//CHECK type
f32
//CHECK type
pd.tensor<256xf32>
//CHECK program
{
(%0) = "builtin.get_parameter" () {parameter_name:(String)conv2d_0.w_0} : () -> pd.tensor<64x3x7x7xf32>
(%1) = "pd.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:(String)data,stop_gradient:(Array)[true]} : () -> pd.tensor<-1x3x224x224xf32>
(%2) = "pd.conv2d" (%1, %0) {data_format:(String)NCHW,dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:(String)EXPLICIT,paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd.tensor<-1x3x224x224xf32>, pd.tensor<64x3x7x7xf32>) -> pd.tensor<-1x64x112x112xf32>
}
//CHECK attribute
(Array)[(pd.DataType)bool,(pd.DataType)float32,(pd.DataType)float64,
(pd.DataType)complex64,(pd.DataType)complex128,(pd.DataType)Undefined,
(pd.DataType)Undefined,(pd.DataType)Undefined,(pd.DataType)Undefined,
(pd.DataType)bfloat16,(pd.DataType)uint8,(pd.DataType)uint32,(pd.DataType)int8,
(pd.DataType)uint16,(pd.DataType)int16,(pd.DataType)int32,(pd.DataType)uint64,(pd.DataType)int64]
//CHECK attribute
(Array)[(pd.Place)Place(gpu:0),(pd.Place)Place(gpu_pinned),(pd.Place)Place(gpu_pinned),
(pd.Place)Place(xpu:0),(pd.Place)Place(ipu:0),(pd.Place)Place(:0),(pd.Place)Place(cpu)]
//CHECK attribute
(Array)[(pd.DataLayout)NHWC,(pd.DataLayout)STRIDED,(pd.DataLayout)NCHW,(pd.DataLayout)Undefined(AnyLayout),
(pd.DataLayout)ONEDNN,(pd.DataLayout)SPARSE_COO,(pd.DataLayout)SPARSE_CSR,(pd.DataLayout)NDHWC,(pd.DataLayout)NCDHW,
(pd.DataLayout)PSTRING_UNION]
//CHECK attribute
(Array)[(Double)1,(Int64)0,(String)1]
//CHECK type
vec[bf16,f64,b,i8,u8,i16,c64,c128]
//CHECK attribute
(String)1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_attribute_storage.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_parser.h"
#include "paddle/ir/core/utils.h"
using PaddleDialect = paddle::dialect::PaddleDialect;
using AttributeStorage = ir::AttributeStorage;
class TestParserDialect : public ir::Dialect {
public:
explicit TestParserDialect(ir::IrContext* context);
static const char* name() { return "tp"; }
void PrintAttribute(ir::Attribute attr, std::ostream& os) const;
ir::Attribute ParseAttribute(ir::IrParser& parser); // NOLINT
private:
void initialize();
};
IR_DECLARE_EXPLICIT_TYPE_ID(TestParserDialect);
IR_DEFINE_EXPLICIT_TYPE_ID(TestParserDialect);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(CharAttributeStorage, char);
class CharAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CharAttribute, CharAttributeStorage);
char data() const;
static CharAttribute Parse(ir::IrParser& parser) { // NOLINT
std::string char_val = parser.ConsumeToken().val_;
return CharAttribute::get(parser.ctx, char_val[0]);
}
};
IR_DECLARE_EXPLICIT_TYPE_ID(CharAttribute);
IR_DEFINE_EXPLICIT_TYPE_ID(CharAttribute);
void TestParserDialect::initialize() { RegisterAttributes<CharAttribute>(); }
char CharAttribute::data() const { return storage()->data(); }
TestParserDialect::TestParserDialect(ir::IrContext* context)
: ir::Dialect(name(), context, ir::TypeId::get<TestParserDialect>()) {
initialize();
}
void TestParserDialect::PrintAttribute(ir::Attribute attr,
std::ostream& os) const {
auto byte_attr = attr.dyn_cast<CharAttribute>();
os << "(tp.char)" << byte_attr.data();
}
ir::Attribute TestParserDialect::ParseAttribute(
ir::IrParser& parser) { // NOLINT
std::string type_name = parser.ConsumeToken().val_;
std::string parenthesis_token_val = parser.ConsumeToken().val_;
IR_ENFORCE(parenthesis_token_val == ")",
"The token value of expectation is ), not " +
parenthesis_token_val + "." + parser.GetErrorLocationInfo());
return CharAttribute::Parse(parser);
}
TEST(IrParserTest, AddAttribute) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
ctx->GetOrRegisterDialect<TestParserDialect>();
std::string op_str =
" (%0) = \"builtin.get_parameter\" () "
"{parameter_name:(String)conv2d_0.w_0,test:(tp.char)a} : () -> "
"pd.tensor<64x3x7x7xf32>";
std::stringstream ss;
ss << op_str;
ir::IrParser* parser = new ir::IrParser(ctx, ss);
ir::Operation* op = parser->ParseOperation();
std::stringstream ssp;
op->Print(ssp);
delete parser;
EXPECT_TRUE(ssp.str() == ss.str());
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_attribute_storage.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_parser.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/utils.h"
using PaddleDialect = paddle::dialect::PaddleDialect;
using AttributeStorage = ir::AttributeStorage;
enum TestType {
AttributeTest = 0,
TypeTest = 1,
ProgramTest = 2,
};
class TestTask {
public:
TestType test_type;
std::string test_info;
public:
TestTask(TestType test_type, std::string test_info) {
this->test_info = test_info;
this->test_type = test_type;
}
};
class ParserTest {
private:
std::ifstream& test_text;
public:
explicit ParserTest(std::ifstream& test_text) : test_text(test_text) {}
TestTask* GetTestTask();
bool ConsumeTestTask(TestTask* test_task, ir::IrContext* ctx);
};
TestTask* ParserTest::GetTestTask() {
if (test_text.peek() == EOF) {
return nullptr;
}
std::string test_info;
while (test_text.peek() != '/') {
test_text.get();
}
while (test_text.peek() != ' ') {
test_text.get();
}
test_text.get();
std::string test_type_info;
while (test_text.peek() != '\n') {
test_type_info += test_text.get();
}
test_text.get();
while (test_text.peek() != '/' && test_text.peek() != EOF) {
test_info += test_text.get();
}
if (test_type_info == "attribute") {
return new TestTask(AttributeTest, test_info);
} else if (test_type_info == "type") {
return new TestTask(TypeTest, test_info);
} else if (test_type_info == "program") {
return new TestTask(ProgramTest, test_info);
}
return nullptr;
}
bool ParserTest::ConsumeTestTask(TestTask* test_task, ir::IrContext* ctx) {
std::string test_info = test_task->test_info;
TestType test_type = test_task->test_type;
std::unique_ptr<ir::IrPrinter> printer;
std::unique_ptr<ir::IrParser> parser;
std::stringstream is(test_info);
parser.reset(new ir::IrParser(ctx, is));
std::vector<std::string> before_parser_tokens;
while (parser->PeekToken().token_type_ != EOF_) {
before_parser_tokens.push_back(parser->ConsumeToken().val_);
}
std::stringstream is_par(test_info);
std::stringstream os;
if (test_type == AttributeTest) {
auto attr = ir::Attribute::Parse(is_par, ctx);
attr.Print(os);
} else if (test_type == ProgramTest) {
auto program = ir::Program::Parse(is_par, ctx);
program->Print(os);
} else if (test_type == TypeTest) {
auto type = ir::Type::Parse(is_par, ctx);
type.Print(os);
}
parser.reset(new ir::IrParser(ctx, os));
std::vector<std::string> after_parser_tokens;
while (parser->PeekToken().token_type_ != EOF_) {
auto str = parser->ConsumeToken().val_;
after_parser_tokens.push_back(str);
}
delete test_task;
if (after_parser_tokens.size() != before_parser_tokens.size()) {
return false;
}
for (size_t i = 0; i < after_parser_tokens.size(); i++) {
if (after_parser_tokens[i] != before_parser_tokens[i]) {
return false;
}
}
return true;
}
TEST(IrParserTest, TestParserByFile) {
ir::IrContext* ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
std::ifstream is("TestParserText.txt");
EXPECT_TRUE(is.is_open());
ParserTest parser_test(is);
bool is_test = false;
while (TestTask* test_task = parser_test.GetTestTask()) {
is_test = true;
bool ans = parser_test.ConsumeTestTask(test_task, ctx);
EXPECT_TRUE(ans);
}
is.close();
EXPECT_TRUE(is_test);
}
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
#include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/ir_parser.h"
#include "paddle/ir/core/ir_printer.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
using PaddleDialect = paddle::dialect::PaddleDialect; using PaddleDialect = paddle::dialect::PaddleDialect;
...@@ -107,3 +109,37 @@ TEST(RegisterInfoTest, MainProgram) { ...@@ -107,3 +109,37 @@ TEST(RegisterInfoTest, MainProgram) {
EXPECT_EQ(unregistered_ops.size(), 1u); EXPECT_EQ(unregistered_ops.size(), 1u);
EXPECT_EQ(unregistered_ops[0], "something must not be registered"); EXPECT_EQ(unregistered_ops[0], "something must not be registered");
} }
TEST(IrParserTest, MainProgram) {
auto p = load_from_file("resnet50_main.prog");
EXPECT_EQ(p.Size(), 1u);
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);
std::stringstream ss;
program->Print(ss);
std::unique_ptr<ir::Program> parser_program = ir::Program::Parse(ss, ctx);
std::stringstream ssp;
parser_program->Print(ssp);
EXPECT_TRUE(ssp.str() == ss.str());
}
TEST(IrParserTest, StartupProgram) {
auto p = load_from_file("resnet50_startup.prog");
EXPECT_EQ(p.Size(), 1u);
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
auto program = paddle::TranslateLegacyProgramToProgram(p);
std::stringstream ss;
program->Print(ss);
std::unique_ptr<ir::Program> parser_program = ir::Program::Parse(ss, ctx);
std::stringstream ssp;
parser_program->Print(ssp);
EXPECT_TRUE(ssp.str() == ss.str());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册