From a2d61455d258e974502f8d3df4dadff1cb462f51 Mon Sep 17 00:00:00 2001 From: xingmingyyj <135400902+xingmingyyj@users.noreply.github.com> Date: Fri, 8 Sep 2023 19:58:33 +0800 Subject: [PATCH] [NewIR] Add parser to deserialize (#55695) * add parser * add parser * add parser * add parser * add parser * add parser * add parser * Update test/cpp/ir/core/ir_parser_test.cc Co-authored-by: kangguangli * add parser * add parser * add parser * Update test/cpp/ir/core/program_translator_test.cc Co-authored-by: kangguangli * Update test/cpp/ir/core/program_translator_test.cc Co-authored-by: kangguangli * Update dialect.h * add parser * add parser * Update CMakeLists.txt * add parser --------- Co-authored-by: kangguangli --- .../dialect/paddle_dialect/ir/pd_attribute.cc | 105 +++++- .../dialect/paddle_dialect/ir/pd_attribute.h | 8 + .../dialect/paddle_dialect/ir/pd_dialect.cc | 59 ++- .../ir/dialect/paddle_dialect/ir/pd_dialect.h | 3 + paddle/ir/core/CMakeLists.txt | 4 + paddle/ir/core/attribute.h | 2 + paddle/ir/core/dialect.h | 14 +- paddle/ir/core/ir_parser.h | 71 ++++ paddle/ir/core/ir_printer.cc | 21 +- paddle/ir/core/parser/ir_parser.cc | 351 ++++++++++++++++++ paddle/ir/core/parser/lexer.cc | 193 ++++++++++ paddle/ir/core/parser/lexer.h | 44 +++ paddle/ir/core/parser/token.h | 39 ++ paddle/ir/core/program.h | 2 + paddle/ir/core/type.h | 2 + test/cpp/ir/core/CMakeLists.txt | 21 ++ test/cpp/ir/core/TestParserText.txt | 43 +++ test/cpp/ir/core/add_dialect_parser_test.cc | 113 ++++++ test/cpp/ir/core/ir_parser_test.cc | 153 ++++++++ test/cpp/ir/core/program_translator_test.cc | 36 ++ 20 files changed, 1270 insertions(+), 14 deletions(-) create mode 100644 paddle/ir/core/ir_parser.h create mode 100644 paddle/ir/core/parser/ir_parser.cc create mode 100644 paddle/ir/core/parser/lexer.cc create mode 100644 paddle/ir/core/parser/lexer.h create mode 100644 paddle/ir/core/parser/token.h create mode 100644 test/cpp/ir/core/TestParserText.txt create mode 100644 
test/cpp/ir/core/add_dialect_parser_test.cc create mode 100644 test/cpp/ir/core/ir_parser_test.cc diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.cc index 3b566edf03c..72cc98447e1 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.cc +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.cc @@ -16,7 +16,7 @@ namespace paddle { namespace dialect { -const phi::IntArray& IntArrayAttribute::data() const { +const phi::IntArray &IntArrayAttribute::data() const { return storage()->GetAsKey(); } @@ -48,6 +48,109 @@ phi::Scalar ScalarAttribute::data() { } } +IntArrayAttribute IntArrayAttribute::Parse(ir::IrParser &parser) { // NOLINT + Token buket_token = parser.ConsumeToken(); + std::vector vec{}; + while (parser.PeekToken().val_ != "]") { + Token val_token = parser.ConsumeToken(); + vec.push_back(atoll(val_token.val_.c_str())); + if (parser.PeekToken().val_ == "]") break; + parser.ConsumeToken(); + } + parser.ConsumeToken(); + return IntArrayAttribute::get(parser.ctx, vec); +} + +// Parse a DataTypeAttribute +// DataTypeAttribute := bool|uint8|int8|uint16|int16|uint32 +// |int32|uint64|int64|float32|complex64 +// |complex128|Undefined|pstring|float16 +// |bfloat16|float64 +DataTypeAttribute DataTypeAttribute::Parse(ir::IrParser &parser) { // NOLINT + std::unordered_map StringToDataType{ + {"bool", phi::DataType::BOOL}, + {"uint8", phi::DataType::UINT8}, + {"int8", phi::DataType::INT8}, + {"uint16", phi::DataType::UINT16}, + {"int16", phi::DataType::INT16}, + {"uint32", phi::DataType::UINT32}, + {"int32", phi::DataType::INT32}, + {"uint64", phi::DataType::UINT64}, + {"int64", phi::DataType::INT64}, + {"float32", phi::DataType::FLOAT32}, + {"complex64", phi::DataType::COMPLEX64}, + {"complex128", phi::DataType::COMPLEX128}, + {"Undefined", phi::DataType::UNDEFINED}, + {"pstring", phi::DataType::PSTRING}, + {"float16", phi::DataType::FLOAT16}, + {"bfloat16", 
phi::DataType::BFLOAT16}, + {"float64", phi::DataType::FLOAT64}}; + std::string datatype_token_val = parser.ConsumeToken().val_; + IR_ENFORCE(StringToDataType.count(datatype_token_val) > 0, + datatype_token_val + " is not defined in DataType." + + parser.GetErrorLocationInfo()); + return DataTypeAttribute::get(parser.ctx, + StringToDataType[datatype_token_val]); +} + +// Parse a PlaceAttribute +// PlaceAttribute := Place(cpu)|Place(gpu:0)|Place(gpu_pinned) +// |Place(xpu:0)|Place(ipu:0)|Place(:0)|undefined +PlaceAttribute PlaceAttribute::Parse(ir::IrParser &parser) { // NOLINT + std::unordered_map StringToPlace{ + {"cpu", phi::CPUPlace{}}, + {"gpu", phi::GPUPlace{}}, + {"gpu_pinned", phi::GPUPinnedPlace{}}, + {"xpu", phi::XPUPlace{}}, + {"ipu", phi::IPUPlace{}}, + {":", phi::CustomPlace{}}, + {"undefined", phi::Place{}}}; + parser.ConsumeAToken("Place"); + parser.ConsumeAToken("("); + std::string place_token_val = parser.ConsumeToken().val_; + IR_ENFORCE(StringToPlace.count(place_token_val) > 0, + place_token_val + " is not defined in Place." 
+ + parser.GetErrorLocationInfo()); + if (parser.PeekToken().val_ == ":") { + parser.ConsumeAToken(":"); + parser.ConsumeToken(); + } else if (place_token_val == ":") { + parser.ConsumeToken(); + } + parser.ConsumeAToken(")"); + return PlaceAttribute::get(parser.ctx, StringToPlace[place_token_val]); +} + +// Parse a DataLayoutAttribute +// DataLayoutAttribute := NHWC|NCHW|Undefined(0)|ONEDNN +// |SPARSE_COO|SPARSE_CSR|NDHWC +// |NCDHW|PSTRING_UNION|STRIDED +DataLayoutAttribute DataLayoutAttribute::Parse( + ir::IrParser &parser) { // NOLINT + std::unordered_map StringToDataLayout{ + {"NHWC", phi::DataLayout::kNHWC}, + {"NCHW", phi::DataLayout::kNCHW}, + {"Undefined", phi::DataLayout::kAnyLayout}, + {"ONEDNN", phi::DataLayout::ONEDNN}, + {"SPARSE_COO", phi::DataLayout::SPARSE_COO}, + {"SPARSE_CSR", phi::DataLayout::SPARSE_CSR}, + {"NDHWC", phi::DataLayout::kNDHWC}, + {"NCDHW", phi::DataLayout::kNCDHW}, + {"PSTRING_UNION", phi::DataLayout::PSTRING_UNION}, + {"STRIDED", phi::DataLayout::STRIDED}}; + std::string datalayout_token_val = parser.ConsumeToken().val_; + IR_ENFORCE(StringToDataLayout.count(datalayout_token_val) > 0, + datalayout_token_val + " is not defined in DataLayout." 
+ + parser.GetErrorLocationInfo()); + if (datalayout_token_val == "Undefined") { + parser.ConsumeAToken("("); + parser.ConsumeAToken("AnyLayout"); + parser.ConsumeAToken(")"); + } + return DataLayoutAttribute::get(parser.ctx, + StringToDataLayout[datalayout_token_val]); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h index ed1f84a56c5..e1d3daab719 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h @@ -18,6 +18,7 @@ #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" #include "paddle/ir/core/attribute.h" #include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/ir_parser.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/enforce.h" @@ -34,6 +35,8 @@ class IntArrayAttribute : public ir::Attribute { return storage() < right.storage(); } + static IntArrayAttribute Parse(ir::IrParser &parser); // NOLINT + const phi::IntArray &data() const; }; @@ -68,6 +71,8 @@ class DataTypeAttribute : public ir::Attribute { return storage() < right.storage(); } + static DataTypeAttribute Parse(ir::IrParser &parser); // NOLINT + phi::DataType data() const; }; @@ -81,6 +86,8 @@ class PlaceAttribute : public ir::Attribute { return storage() < right.storage(); } + static PlaceAttribute Parse(ir::IrParser &parser); // NOLINT + phi::Place data() const; }; @@ -95,6 +102,7 @@ class DataLayoutAttribute : public ir::Attribute { return storage() < right.storage(); } + static DataLayoutAttribute Parse(ir::IrParser &parser); // NOLINT phi::DataLayout data() const; }; diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc index e07075a2c02..82169dafc59 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc +++ 
b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.cc @@ -83,9 +83,12 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const { } void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { + os << "(" << attr.dialect().name(); + os << '.'; if (auto int_array_attr = attr.dyn_cast()) { phi::IntArray data = int_array_attr.data(); - os << "IntArray["; + os << "IntArray)" + << "["; const auto &inner_data = data.GetData(); ir::PrintInterleave( inner_data.begin(), @@ -94,16 +97,64 @@ void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const { [&os]() { os << ","; }); os << "]"; } else if (auto data_type_attr = attr.dyn_cast()) { - os << data_type_attr.data(); + os << "DataType)" << data_type_attr.data(); } else if (auto place_type_attr = attr.dyn_cast()) { - os << place_type_attr.data(); + os << "Place)" << place_type_attr.data(); } else if (auto data_layout_attr = attr.dyn_cast()) { - os << data_layout_attr.data(); + os << "DataLayout)" << data_layout_attr.data(); } else { os << "<#AttrNotImplemented>"; } } +ir::Type PaddleDialect::ParseType(ir::IrParser &parser) { // NOLINT + parser.ConsumeAToken("pd.tensor"); + parser.ConsumeAToken("<"); + std::vector dim{}; + Token dim_token = parser.PeekToken(); + while (dim_token.token_type_ == DIGIT) { + dim_token = parser.ConsumeToken(); + dim.push_back(atoi(dim_token.val_.c_str())); + std::string peek_token_val = parser.PeekToken().val_; + if (peek_token_val[0] != 'x') { + break; + } + parser.ConsumeToken(); + parser.lexer->Unget(peek_token_val.size() - 1); + if (parser.PeekToken().token_type_ != DIGIT) { + break; + } + } + phi::DDim ddim = phi::make_ddim(dim); + ir::Type dtype = parser.ParseType(); + std::vector> lod; + std::vector lodv; + lodv.push_back(0); + lod.push_back(lodv); + parser.ConsumeAToken(">"); + return DenseTensorType::get( + parser.ctx, dtype, ddim, phi::DataLayout::UNDEFINED, lod, 0); +} + +ir::Attribute PaddleDialect::ParseAttribute(ir::IrParser 
&parser) { // NOLINT + std::string type_name = parser.ConsumeToken().val_; + std::string attribute_name = + type_name.substr(type_name.find('.') + 1, std::string::npos); + parser.ConsumeAToken(")"); + if (attribute_name == "IntArray") { + return IntArrayAttribute::Parse(parser); + } else if (attribute_name == "DataType") { + return DataTypeAttribute::Parse(parser); + } else if (attribute_name == "Place") { + return PlaceAttribute::Parse(parser); + } else if (attribute_name == "DataLayout") { + return DataLayoutAttribute::Parse(parser); + } else { + IR_THROW("No function to parse " + attribute_name + " exists!" + + parser.GetErrorLocationInfo()); + } +} + void PaddleDialect::PrintOperation(ir::Operation *op, ir::IrPrinter &printer) const { if (auto if_op = op->dyn_cast()) { diff --git a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h index b9e9567e790..285a796982f 100644 --- a/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h +++ b/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h @@ -25,6 +25,9 @@ class PaddleDialect : public ir::Dialect { static const char* name() { return "pd"; } + ir::Type ParseType(ir::IrParser& parser) override; // NOLINT + ir::Attribute ParseAttribute(ir::IrParser& parser) override; // NOLINT + void PrintType(ir::Type type, std::ostream& os) const override; void PrintAttribute(ir::Attribute type, std::ostream& os) const override; diff --git a/paddle/ir/core/CMakeLists.txt b/paddle/ir/core/CMakeLists.txt index 39a3d71b271..138b102fcbd 100644 --- a/paddle/ir/core/CMakeLists.txt +++ b/paddle/ir/core/CMakeLists.txt @@ -3,4 +3,8 @@ set(NEWIR_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/ir") file(GLOB IR_SRCS "*.cc") +file(GLOB IR_PARSER_SRCS "parser/*.cc") + +list(APPEND IR_SRCS ${IR_PARSER_SRCS}) + ir_library(ir_core SRCS ${IR_SRCS} DEPS ddim) diff --git a/paddle/ir/core/attribute.h b/paddle/ir/core/attribute.h index 4315e13b0fc..d83ea3b3c60 100644 --- 
a/paddle/ir/core/attribute.h +++ b/paddle/ir/core/attribute.h @@ -68,6 +68,8 @@ class IR_API Attribute { /// @param os void Print(std::ostream &os) const; + static Attribute Parse(std::istream &is, IrContext *ctx); + /// /// \brief Methods for type judgment and cast. /// diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h index be67898dd98..f07a4242f36 100644 --- a/paddle/ir/core/dialect.h +++ b/paddle/ir/core/dialect.h @@ -29,7 +29,7 @@ namespace ir { class Operation; class IrPrinter; - +class IrParser; class DialectInterface; /// /// \brief Dialect can basically be understood as a namespace. In Dialect, we @@ -145,9 +145,21 @@ class IR_API Dialect { IR_THROW("dialect has no registered attribute printing hook"); } + virtual Type ParseType(IrParser &parser) { // NOLINT + IR_THROW("dialect has no registered type parsing hook"); + } + + virtual Attribute ParseAttribute(IrParser &parser) { // NOLINT + IR_THROW("dialect has no registered attribute parsing hook"); + } + virtual void PrintOperation(Operation *op, IrPrinter &printer) const; // NOLINT + virtual Operation ParseOperation(IrParser &parser) { // NOLINT + IR_THROW("dialect has no registered operation parsing hook"); + } + private: Dialect(const Dialect &) = delete; diff --git a/paddle/ir/core/ir_parser.h b/paddle/ir/core/ir_parser.h new file mode 100644 index 00000000000..dbba3e2aaba --- /dev/null +++ b/paddle/ir/core/ir_parser.h @@ -0,0 +1,71 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/operation.h" +#include "paddle/ir/core/parser/lexer.h" +#include "paddle/ir/core/program.h" + +using OpResultMap = std::map; +using AttributeMap = std::unordered_map; +using OpAttributeInfoMap = std::map; + +namespace ir { +class IrParser { + public: + std::unique_ptr lexer; + IrContext* ctx; + OpResultMap opresultmap; + std::unique_ptr builder; + + public: + IrParser(IrContext* ctx, std::istream& is); + + ~IrParser() = default; + + Token ConsumeToken(); + + Token PeekToken(); + + std::unique_ptr ParseProgram(); + + void ParseRegion(Region& region); // NOLINT + + void ParseBlock(Block& block); // NOLINT + + Operation* ParseOperation(); + + OpInfo ParseOpInfo(); + + std::vector ParseOpResultList(); + + std::vector ParseOpRandList(); + + AttributeMap ParseAttributeMap(); + + std::vector ParseTypeList(); + + OpResult GetNullValue(); + + Type ParseType(); + + Attribute ParseAttribute(); + + std::string GetErrorLocationInfo(); + + void ConsumeAToken(std::string expect_token_val); +}; + +} // namespace ir diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index 16d6568ecc4..0d0ce64f679 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -87,22 +87,27 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { } if (auto s = attr.dyn_cast()) { - os << s.AsString(); + os << "(String)" << s.AsString(); } else if (auto b = attr.dyn_cast()) { - os << b.data(); + if (b.data()) { + os << "true"; + } else { + os << "false"; + } } else if (auto f = attr.dyn_cast()) { - os << f.data(); + os << "(Float)" << f.data(); } else if (auto d = attr.dyn_cast()) { - os << d.data(); + os << "(Double)" << d.data(); } else if (auto i = attr.dyn_cast()) { - os << i.data(); + os << "(Int32)" << i.data(); } else if (auto i = attr.dyn_cast()) { - os << i.data(); + 
os << "(Int64)" << i.data(); } else if (auto p = attr.dyn_cast()) { - os << p.data(); + os << "(Pointer)" << p.data(); } else if (auto arr = attr.dyn_cast()) { const auto& vec = arr.AsVector(); - os << "array["; + os << "(Array)" + << "["; PrintInterleave( vec.begin(), vec.end(), diff --git a/paddle/ir/core/parser/ir_parser.cc b/paddle/ir/core/parser/ir_parser.cc new file mode 100644 index 00000000000..8d7e4376351 --- /dev/null +++ b/paddle/ir/core/parser/ir_parser.cc @@ -0,0 +1,351 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ir/core/ir_parser.h" + +#include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/builtin_type.h" + +namespace ir { +IrParser::IrParser(IrContext* ctx, std::istream& is) { + lexer.reset(new Lexer{is}); + this->ctx = ctx; + builder.reset(new Builder{ctx}); +} + +Token IrParser::ConsumeToken() { + auto token = lexer->ConsumeToken(); + return token; +} + +std::string IrParser::GetErrorLocationInfo() { + return "The error occurred in line " + std::to_string(lexer->GetLine()) + + ", column " + std::to_string(lexer->GetColumn()); +} + +Token IrParser::PeekToken() { + auto token = lexer->ConsumeToken(); + if (token.token_type_ != EOF_) { + lexer->Unget(token.val_.size()); + } + return token; +} + +void IrParser::ConsumeAToken(std::string expect_token_val) { + std::string token_val = ConsumeToken().val_; + IR_ENFORCE(token_val == expect_token_val, + "The token value of expectation is " + expect_token_val + " ,not" + + token_val + "." + GetErrorLocationInfo()); +} + +// Type := BuiltinType | OtherDialectsDefineType +// BuiltinType := <> | bf16 | f16 | f32 | f64 +// := | b | i8 | u8 | i16 | i32 | i64 | index | c64 +// := | c128 | VectorType +// VectorType := 'vec' '[' Type(,Type)* ']' +Type IrParser::ParseType() { + Token type_token = PeekToken(); + std::string type_val = type_token.val_; + if (type_val == "<>") { + ConsumeToken(); + return Type(nullptr); + } else if (type_val == "bf16") { + ConsumeToken(); + return builder->bfloat16_type(); + } else if (type_val == "f16") { + ConsumeToken(); + return Float16Type::get(ctx); + } else if (type_val == "f32") { + ConsumeToken(); + return builder->float32_type(); + } else if (type_val == "f64") { + ConsumeToken(); + return builder->float64_type(); + } else if (type_val == "b") { + ConsumeToken(); + return builder->bool_type(); + } else if (type_val == "i8") { + ConsumeToken(); + return builder->int8_type(); + } else if (type_val == "u8") { + ConsumeToken(); + return builder->uint8_type(); + } else if 
(type_val == "i16") { + ConsumeToken(); + return builder->int16_type(); + } else if (type_val == "i32") { + ConsumeToken(); + return Int32Type::get(ctx); + } else if (type_val == "i64") { + ConsumeToken(); + return Int64Type::get(ctx); + } else if (type_val == "index") { + ConsumeToken(); + return IndexType::get(ctx); + } else if (type_val == "c64") { + ConsumeToken(); + return builder->complex64_type(); + } else if (type_val == "c128") { + ConsumeToken(); + return builder->complex128_type(); + } else if (type_val == "vec") { + ConsumeAToken("vec"); + ConsumeAToken("["); + std::vector vec_type; + Token vec_type_token = PeekToken(); + while (vec_type_token.val_ != "]") { + Type cur_type = ParseType(); + vec_type.push_back(cur_type); + vec_type_token = ConsumeToken(); + } + return VectorType::get(ctx, vec_type); + } else { + IR_ENFORCE(type_val.find('.') != std::string::npos, + "No function parsing " + type_val + " exists!" + + GetErrorLocationInfo()); + auto dialect_name = type_val.substr(0, type_val.find('.')); + auto dialect = ctx->GetRegisteredDialect(dialect_name); + return dialect->ParseType(*this); + } +} + +// Attribute := BuiltinAttribute | OtherDialectsDefineAttribute +// BuiltinAttribute := Bool | String | Float | Double | Int32 | +// := | Int64 | Pointer | ArrayAttribute +// ArrayAttribute := '[' Atribute(,Attribute)* ']' +Attribute IrParser::ParseAttribute() { + auto parenthesis_token = ConsumeToken(); + if (parenthesis_token.val_ == "true" || parenthesis_token.val_ == "false") { + return builder->bool_attr(parenthesis_token.val_ == "true"); + } + std::string attribute_type = PeekToken().val_; + if (attribute_type == "String") { + ConsumeAToken("String"); + ConsumeAToken(")"); + std::string val = ConsumeToken().val_; + return builder->str_attr(val); + } else if (attribute_type == "Float") { + ConsumeAToken("Float"); + ConsumeAToken(")"); + std::string val = ConsumeToken().val_; + return builder->float_attr(atof(val.c_str())); + } else if (attribute_type 
== "Double") { + ConsumeAToken("Double"); + ConsumeAToken(")"); + std::string val = ConsumeToken().val_; + return builder->double_attr(atof(val.c_str())); + } else if (attribute_type == "Int32") { + ConsumeAToken("Int32"); + ConsumeAToken(")"); + std::string val = ConsumeToken().val_; + return builder->int32_attr(atoi(val.c_str())); + } else if (attribute_type == "Int64") { + ConsumeAToken("Int64"); + ConsumeAToken(")"); + std::string val = ConsumeToken().val_; + return builder->int64_attr(atoll(val.c_str())); + } else if (attribute_type == "Pointer") { + IR_THROW("This attribute is not currently supported by parser"); + } else if (attribute_type == "Array") { + ConsumeAToken("Array"); + ConsumeAToken(")"); + ConsumeAToken("["); + std::vector array_attribute; + while (PeekToken().val_ != "]") { + array_attribute.push_back(ParseAttribute()); + if (PeekToken().val_ == "]") break; + ConsumeAToken(","); + } + ConsumeAToken("]"); + return builder->array_attr(array_attribute); + } else { + IR_ENFORCE(attribute_type.find('.') != std::string::npos, + "No function parsing " + attribute_type + " exists!" 
+ + GetErrorLocationInfo()); + auto dialect_name = attribute_type.substr(0, attribute_type.find('.')); + auto dialect = ctx->GetRegisteredDialect(dialect_name); + return dialect->ParseAttribute(*this); + } +} + +// Program := [ParameterList]ModuleOp +// ModuleOp := Region +std::unique_ptr IrParser::ParseProgram() { + std::unique_ptr program(new Program{ctx}); + auto top_level_op = program->module_op(); + auto& region = top_level_op->region(0); + ParseRegion(region); + + return program; +} + +// Region := Block +void IrParser::ParseRegion(Region& region) { // NOLINT + ParseBlock(*region.front()); + IR_ENFORCE(PeekToken().val_ != "{", + "Only one block in a region is supported"); +} + +// Block := "{" {Operation} "}" +void IrParser::ParseBlock(Block& block) { // NOLINT + ConsumeAToken("{"); + while (PeekToken().val_ != "}") { + auto op = ParseOperation(); + block.push_back(op); + } + ConsumeAToken("}"); +} + +// Operation := OpResultList ":=" Opname "(" OprandList ? ")" AttributeMap ":" +// FunctionType +// FunctionType := "(" TypeList ")" "->" TypeList +Operation* IrParser::ParseOperation() { + std::vector opresultindex = ParseOpResultList(); + ConsumeAToken("="); + + OpInfo opinfo = ParseOpInfo(); + + std::vector inputs = ParseOpRandList(); + + ir::AttributeMap attributeMap = ParseAttributeMap(); + + ConsumeAToken(":"); + ConsumeAToken("("); + ParseTypeList(); + ConsumeAToken(")"); + ConsumeAToken("->"); + + std::vector type_vector = ParseTypeList(); + + Operation* op = + Operation::Create(inputs, attributeMap, type_vector, opinfo, 0); + + for (uint32_t i = 0; i < op->num_results(); i++) { + std::string key_t = opresultindex[i]; + opresultmap[key_t] = op->result(i); + } + + return op; +} + +// OpResultList := ValueList +// ValueList := ValueId(,ValueId)* +std::vector IrParser::ParseOpResultList() { + std::vector opresultindex{}; + ConsumeAToken("("); + Token index_token = ConsumeToken(); + while (index_token.val_ != ")") { + if (index_token.token_type_ == NULL_) { 
+ opresultindex.push_back("null"); + } else { + std::string str = index_token.val_; + opresultindex.push_back(str); + } + if (ConsumeToken().val_ == ")") break; + index_token = ConsumeToken(); + } + + return opresultindex; +} + +// OpName := "\"" StringIdentifer "." StringIdentifer "\"" +OpInfo IrParser::ParseOpInfo() { + Token opname_token = ConsumeToken(); + std::string opname = + opname_token.val_.substr(1, opname_token.val_.size() - 2); + return ctx->GetRegisteredOpInfo(opname); +} + +// OprandList := ValueList +// ValueList := ValueId(,ValueId)* +std::vector IrParser::ParseOpRandList() { + ConsumeAToken("("); + std::vector inputs{}; + Token ind_token = ConsumeToken(); + while (ind_token.val_ != ")") { + std::string t = ""; + if (ind_token.token_type_ == NULL_) { + inputs.push_back(GetNullValue()); + } else { + t = ind_token.val_; + inputs.push_back(opresultmap[t]); + } + Token token = ConsumeToken(); + if (token.val_ == ")") { + break; + } + ind_token = ConsumeToken(); + } + return inputs; +} + +// AttributeMap := "{" AttributeEntry,(,AttributeEntry)* "}" +// AttributeEntry := StringIdentifer:Attribute +AttributeMap IrParser::ParseAttributeMap() { + AttributeMap attribute_map{}; + ConsumeAToken("{"); + Token key_token = ConsumeToken(); + while (key_token.val_ != "}") { + ConsumeAToken(":"); + attribute_map[key_token.val_] = ParseAttribute(); + std::string token_val = ConsumeToken().val_; + if (token_val == "}") { + break; + } else if (token_val == ",") { + key_token = ConsumeToken(); + } else { + IR_ENFORCE((token_val == "}") || (token_val == ","), + "The token value of expectation is } or , , not " + token_val + + "." 
+ GetErrorLocationInfo()); + } + } + return attribute_map; +} + +// TypeList := Type(,Type)* +std::vector IrParser::ParseTypeList() { + std::vector type_vector{}; + while (PeekToken().val_ != "(" && PeekToken().val_ != "}" && + PeekToken().val_ != ")") { + type_vector.push_back(ParseType()); + if (PeekToken().val_ == "}" || PeekToken().val_ == "(" || + PeekToken().val_ == ")" || PeekToken().token_type_ == EOF_) + break; + ConsumeAToken(","); + } + return type_vector; +} + +OpResult IrParser::GetNullValue() { + Value v{nullptr}; // stack object: the old heap allocation was never freed + OpResult* opresult = static_cast(&v); + return *opresult; +} + +Attribute Attribute::Parse(std::istream& is, IrContext* ctx) { + IrParser parser(ctx, is); + return parser.ParseAttribute(); +} + +Type Type::Parse(std::istream& is, IrContext* ctx) { + IrParser parser(ctx, is); + return parser.ParseType(); +} + +std::unique_ptr Program::Parse(std::istream& is, IrContext* ctx) { + IrParser parser(ctx, is); + return parser.ParseProgram(); +} + +} // namespace ir diff --git a/paddle/ir/core/parser/lexer.cc b/paddle/ir/core/parser/lexer.cc new file mode 100644 index 00000000000..af1530a5b29 --- /dev/null +++ b/paddle/ir/core/parser/lexer.cc @@ -0,0 +1,193 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/ir/core/parser/lexer.h" + +Token Lexer::ConsumeToken() { + SkipWhitespace(); + if (auto token = LexIdentifer()) { + return *token; + } else if (auto token = LexNumberOrArraow()) { + return *token; + } else if (auto token = LexEndTagOrNullVal()) { + return *token; + } else if (auto token = LexValueId()) { + return *token; + } else if (auto token = LexOpName()) { + return *token; + } else if (auto token = LexEOF()) { + return *token; + } else { + return Token{"Error", NULL_}; + } +} + +char Lexer::GetChar() { + char c = is.get(); + if (c == '\n') { + line++; + column = 1; + } else { + column++; + } + return c; +} + +size_t Lexer::GetColumn() { return column; } + +size_t Lexer::GetLine() { return line; } + +void Lexer::SkipWhitespace() { + while (IsSpace(is.peek())) { + GetChar(); + } +} + +std::unique_ptr Lexer::LexIdentifer() { + if ((!isalpha(is.peek()) && is.peek() != '_') || IsEndTag(is.peek())) { + return nullptr; + } + std::string token_identifier = ""; + while (isalnum(is.peek()) || is.peek() == '_' || is.peek() == '.') { + token_identifier += GetChar(); + } + std::unique_ptr token(new Token{token_identifier, IDENTIFER}); + return token; +} + +std::unique_ptr Lexer::LexNumberOrArraow() { + if (!isdigit(is.peek()) && is.peek() != '-') { + return nullptr; + } + + std::string token_digit = ""; + token_digit += GetChar(); + + if (token_digit[0] == '-' && is.peek() == '>') { + GetChar(); + std::unique_ptr arrow_token(new Token{"->", ARRAOW}); + return arrow_token; + } + while (isdigit(is.peek())) { + token_digit += GetChar(); + } + if (is.peek() == '.') { + token_digit += GetChar(); + while (isdigit(is.peek())) { + token_digit += GetChar(); + } + } + if (is.peek() == 'e') { + token_digit += GetChar(); + if (is.peek() == '+' || is.peek() == '-') { + token_digit += GetChar(); + } + while (isdigit(is.peek())) { + token_digit += GetChar(); + } + std::unique_ptr sdigit_token(new Token{token_digit, SDIGIT}); + return sdigit_token; + } + 
std::unique_ptr digit_token(new Token{token_digit, DIGIT}); + return digit_token; +} + +std::unique_ptr Lexer::LexEndTagOrNullVal() { + if (!IsEndTag(is.peek())) { + return nullptr; + } + std::string token_end = ""; + token_end += GetChar(); + if ((token_end[0] == '<' && (is.peek() != '<' && is.peek() != '#')) || + token_end[0] != '<') { + std::unique_ptr endtag_token(new Token{token_end, ENDTAG}); + return endtag_token; + } + if (is.peek() == '<') { + std::string token_null_val = ""; + GetChar(); + while (is.peek() != '>' && is.peek() != EOF) { + token_null_val += GetChar(); + } + GetChar(); + GetChar(); + std::unique_ptr null_token( + new Token{"<<" + token_null_val + ">>", NULL_}); + return null_token; + } else { + std::string token_attrnull = ""; + while (is.peek() != '>' && is.peek() != EOF) { + token_attrnull += GetChar(); + } + GetChar(); + std::unique_ptr null_token( + new Token{"<" + token_attrnull + ">", NULL_}); + return null_token; + } +} + +std::unique_ptr Lexer::LexValueId() { + if (is.peek() != '%') { + return nullptr; + } + std::string token_valueid = ""; + token_valueid += GetChar(); + + while (isdigit(is.peek())) { + token_valueid += GetChar(); + } + std::unique_ptr valueid_token(new Token{token_valueid, VALUEID}); + return valueid_token; +} + +std::unique_ptr Lexer::LexEOF() { + if (is.peek() == EOF) { + std::unique_ptr eof_token(new Token{"LEX_DOWN", EOF_}); + return eof_token; + } else { + return nullptr; + } +} + +std::unique_ptr Lexer::LexOpName() { + if (is.peek() != '"') { + return nullptr; + } + GetChar(); + std::string token_opname = ""; + while (is.peek() != '"' && is.peek() != EOF) { + token_opname += GetChar(); + } + GetChar(); + std::unique_ptr opname_token( + new Token{"\"" + token_opname + "\"", OPNAME}); + return opname_token; +} + +bool Lexer::IsSpace(char c) { + return c == ' ' || c == '\n' || c == '\t' || c == '\f'; +} + +bool Lexer::IsEndTag(char c) { + return c == '{' || c == '}' || c == '(' || c == ')' || c == ':' || c == '>' || + c == ',' || c == ']' || c == '[' || c == '+' || c == 
'=' || c == '<'; +} + +void Lexer::Unget(const int len) { + if (is.eof()) { + is.clear(); + } + column -= len; + is.seekg(-len, std::ios::cur); +} diff --git a/paddle/ir/core/parser/lexer.h b/paddle/ir/core/parser/lexer.h new file mode 100644 index 00000000000..0561e1f60ca --- /dev/null +++ b/paddle/ir/core/parser/lexer.h @@ -0,0 +1,44 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/ir/core/parser/token.h" + +class Lexer { + private: + std::istream& is; + size_t line = 1; + size_t column = 1; + + public: + explicit Lexer(std::istream& is) : is(is) {} + ~Lexer() = default; + Token ConsumeToken(); + std::unique_ptr LexIdentifer(); + std::unique_ptr LexNumberOrArraow(); + std::unique_ptr LexEndTagOrNullVal(); + std::unique_ptr LexValueId(); + std::unique_ptr LexEOF(); + std::unique_ptr LexOpName(); + char GetChar(); + void SkipWhitespace(); + bool IsEndTag(char); + bool IsSpace(char); + size_t GetLine(); + size_t GetColumn(); + void Unget(const int len); +}; diff --git a/paddle/ir/core/parser/token.h b/paddle/ir/core/parser/token.h new file mode 100644 index 00000000000..78a20a691c8 --- /dev/null +++ b/paddle/ir/core/parser/token.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +enum Token_type { + EOF_ = -1, + IDENTIFER = 0, + DIGIT = 1, + SDIGIT = 2, + ENDTAG = 3, + VALUEID = 4, + OPNAME = 5, + ARRAOW = 6, + NULL_ = 7, +}; + +struct Token { + public: + std::string val_; + Token_type token_type_; + Token() = default; + Token(const std::string& val, Token_type token_type) { + val_ = val; + token_type_ = token_type; + } +}; diff --git a/paddle/ir/core/program.h b/paddle/ir/core/program.h index 6f44a3fe469..bf9c3721096 100644 --- a/paddle/ir/core/program.h +++ b/paddle/ir/core/program.h @@ -52,6 +52,8 @@ class IR_API Program { void Print(std::ostream& os) const; + static std::unique_ptr Parse(std::istream& is, IrContext* ctx); + Block* block() { return module_.block(); } const Block* block() const { return module_op().block(); } diff --git a/paddle/ir/core/type.h b/paddle/ir/core/type.h index df148f17a23..f27503b3731 100644 --- a/paddle/ir/core/type.h +++ b/paddle/ir/core/type.h @@ -84,6 +84,8 @@ class IR_API Type { void Print(std::ostream &os) const; + static Type Parse(std::istream &is, IrContext *ctx); + /// /// \brief Enable hashing Type. 
/// diff --git a/test/cpp/ir/core/CMakeLists.txt b/test/cpp/ir/core/CMakeLists.txt index 80cd506648c..14ea9dc1372 100644 --- a/test/cpp/ir/core/CMakeLists.txt +++ b/test/cpp/ir/core/CMakeLists.txt @@ -65,6 +65,9 @@ file( ${CMAKE_CURRENT_BINARY_DIR}/resnet50_startup.prog EXPECTED_MD5 6affc5f40f0f0bb84d956919b95eaf50) +copy_if_different(${CMAKE_CURRENT_SOURCE_DIR}/TestParserText.txt + ${CMAKE_CURRENT_BINARY_DIR}/TestParserText.txt) + cc_test_old( program_translator_test SRCS @@ -75,6 +78,24 @@ cc_test_old( pd_dialect ir) +cc_test_old( + add_dialect_parser_test + SRCS + add_dialect_parser_test.cc + DEPS + gtest + pd_dialect + ir) + +cc_test_old( + ir_parser_test + SRCS + ir_parser_test.cc + DEPS + gtest + pd_dialect + ir) + cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest ir) cc_test_old( ir_op_yaml_info_parser_test diff --git a/test/cpp/ir/core/TestParserText.txt b/test/cpp/ir/core/TestParserText.txt new file mode 100644 index 00000000000..e90248086eb --- /dev/null +++ b/test/cpp/ir/core/TestParserText.txt @@ -0,0 +1,43 @@ + +//CHECK attribute +(String)sdfgs.sdsd + +//CHECK type +f32 + +//CHECK type +pd.tensor<256xf32> + +//CHECK program +{ + (%0) = "builtin.get_parameter" () {parameter_name:(String)conv2d_0.w_0} : () -> pd.tensor<64x3x7x7xf32> + (%1) = "pd.feed" () {col:(Int32)0,is_persisable:(Array)[false],name:(String)data,stop_gradient:(Array)[true]} : () -> pd.tensor<-1x3x224x224xf32> + (%2) = "pd.conv2d" (%1, %0) {data_format:(String)NCHW,dilations:(Array)[(Int32)1,(Int32)1],groups:(Int32)1,is_persisable:(Array)[false],padding_algorithm:(String)EXPLICIT,paddings:(Array)[(Int32)3,(Int32)3],stop_gradient:(Array)[false],strides:(Array)[(Int32)2,(Int32)2]} : (pd.tensor<-1x3x224x224xf32>, pd.tensor<64x3x7x7xf32>) -> pd.tensor<-1x64x112x112xf32> +} + +//CHECK attribute +(Array)[(pd.DataType)bool,(pd.DataType)float32,(pd.DataType)float64, +(pd.DataType)complex64,(pd.DataType)complex128,(pd.DataType)Undefined, 
+(pd.DataType)Undefined,(pd.DataType)Undefined,(pd.DataType)Undefined, +(pd.DataType)bfloat16,(pd.DataType)uint8,(pd.DataType)uint32,(pd.DataType)int8, +(pd.DataType)uint16,(pd.DataType)int16,(pd.DataType)int32,(pd.DataType)uint64,(pd.DataType)int64] + + +//CHECK attribute +(Array)[(pd.Place)Place(gpu:0),(pd.Place)Place(gpu_pinned),(pd.Place)Place(gpu_pinned), +(pd.Place)Place(xpu:0),(pd.Place)Place(ipu:0),(pd.Place)Place(:0),(pd.Place)Place(cpu)] + + +//CHECK attribute +(Array)[(pd.DataLayout)NHWC,(pd.DataLayout)STRIDED,(pd.DataLayout)NCHW,(pd.DataLayout)Undefined(AnyLayout), +(pd.DataLayout)ONEDNN,(pd.DataLayout)SPARSE_COO,(pd.DataLayout)SPARSE_CSR,(pd.DataLayout)NDHWC,(pd.DataLayout)NCDHW, +(pd.DataLayout)PSTRING_UNION] + +//CHECK attribute +(Array)[(Double)1,(Int64)0,(String)1] + +//CHECK type +vec[bf16,f64,b,i8,u8,i16,c64,c128] + +//CHECK attribute +(String)1 diff --git a/test/cpp/ir/core/add_dialect_parser_test.cc b/test/cpp/ir/core/add_dialect_parser_test.cc new file mode 100644 index 00000000000..9bc39bb8d96 --- /dev/null +++ b/test/cpp/ir/core/add_dialect_parser_test.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/ir_adaptor/translator/translate.h" +#include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/attribute_base.h" +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_attribute_storage.h" +#include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/ir_parser.h" +#include "paddle/ir/core/utils.h" + +using PaddleDialect = paddle::dialect::PaddleDialect; +using AttributeStorage = ir::AttributeStorage; + +class TestParserDialect : public ir::Dialect { + public: + explicit TestParserDialect(ir::IrContext* context); + + static const char* name() { return "tp"; } + + void PrintAttribute(ir::Attribute attr, std::ostream& os) const; + + ir::Attribute ParseAttribute(ir::IrParser& parser); // NOLINT + + private: + void initialize(); +}; + +IR_DECLARE_EXPLICIT_TYPE_ID(TestParserDialect); +IR_DEFINE_EXPLICIT_TYPE_ID(TestParserDialect); + +DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(CharAttributeStorage, char); + +class CharAttribute : public ir::Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CharAttribute, CharAttributeStorage); + + char data() const; + + static CharAttribute Parse(ir::IrParser& parser) { // NOLINT + std::string char_val = parser.ConsumeToken().val_; + return CharAttribute::get(parser.ctx, char_val[0]); + } +}; + +IR_DECLARE_EXPLICIT_TYPE_ID(CharAttribute); + +IR_DEFINE_EXPLICIT_TYPE_ID(CharAttribute); + +void TestParserDialect::initialize() { RegisterAttributes(); } + +char CharAttribute::data() const { return storage()->data(); } + +TestParserDialect::TestParserDialect(ir::IrContext* context) + : ir::Dialect(name(), context, ir::TypeId::get()) { + initialize(); +} + +void TestParserDialect::PrintAttribute(ir::Attribute attr, + std::ostream& os) const { + auto 
byte_attr = attr.dyn_cast(); + os << "(tp.char)" << byte_attr.data(); +} + +ir::Attribute TestParserDialect::ParseAttribute( + ir::IrParser& parser) { // NOLINT + std::string type_name = parser.ConsumeToken().val_; + std::string parenthesis_token_val = parser.ConsumeToken().val_; + IR_ENFORCE(parenthesis_token_val == ")", + "The token value of expectation is ), not " + + parenthesis_token_val + "." + parser.GetErrorLocationInfo()); + return CharAttribute::Parse(parser); +} + +TEST(IrParserTest, AddAttribute) { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + std::string op_str = + " (%0) = \"builtin.get_parameter\" () " + "{parameter_name:(String)conv2d_0.w_0,test:(tp.char)a} : () -> " + "pd.tensor<64x3x7x7xf32>"; + std::stringstream ss; + ss << op_str; + ir::IrParser* parser = new ir::IrParser(ctx, ss); + ir::Operation* op = parser->ParseOperation(); + std::stringstream ssp; + op->Print(ssp); + delete parser; + EXPECT_TRUE(ssp.str() == ss.str()); +} diff --git a/test/cpp/ir/core/ir_parser_test.cc b/test/cpp/ir/core/ir_parser_test.cc new file mode 100644 index 00000000000..39abf960583 --- /dev/null +++ b/test/cpp/ir/core/ir_parser_test.cc @@ -0,0 +1,153 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/ir_adaptor/translator/translate.h" +#include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/attribute_base.h" +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_attribute_storage.h" +#include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/ir_parser.h" +#include "paddle/ir/core/ir_printer.h" +#include "paddle/ir/core/utils.h" + +using PaddleDialect = paddle::dialect::PaddleDialect; +using AttributeStorage = ir::AttributeStorage; + +enum TestType { + AttributeTest = 0, + TypeTest = 1, + ProgramTest = 2, +}; + +class TestTask { + public: + TestType test_type; + std::string test_info; + + public: + TestTask(TestType test_type, std::string test_info) { + this->test_info = test_info; + this->test_type = test_type; + } +}; + +class ParserTest { + private: + std::ifstream& test_text; + + public: + explicit ParserTest(std::ifstream& test_text) : test_text(test_text) {} + TestTask* GetTestTask(); + bool ConsumeTestTask(TestTask* test_task, ir::IrContext* ctx); +}; + +TestTask* ParserTest::GetTestTask() { + if (test_text.peek() == EOF) { + return nullptr; + } + std::string test_info; + while (test_text.peek() != '/') { + test_text.get(); + } + while (test_text.peek() != ' ') { + test_text.get(); + } + test_text.get(); + std::string test_type_info; + while (test_text.peek() != '\n') { + test_type_info += test_text.get(); + } + test_text.get(); + while (test_text.peek() != '/' && test_text.peek() != EOF) { + test_info += test_text.get(); + } + if (test_type_info == "attribute") { + return new TestTask(AttributeTest, test_info); + } else if (test_type_info == "type") { + return new TestTask(TypeTest, test_info); + } else if (test_type_info == "program") { + return new 
TestTask(ProgramTest, test_info); + } + return nullptr; +} + +bool ParserTest::ConsumeTestTask(TestTask* test_task, ir::IrContext* ctx) { + std::string test_info = test_task->test_info; + TestType test_type = test_task->test_type; + std::unique_ptr printer; + std::unique_ptr parser; + std::stringstream is(test_info); + parser.reset(new ir::IrParser(ctx, is)); + std::vector before_parser_tokens; + while (parser->PeekToken().token_type_ != EOF_) { + before_parser_tokens.push_back(parser->ConsumeToken().val_); + } + std::stringstream is_par(test_info); + std::stringstream os; + if (test_type == AttributeTest) { + auto attr = ir::Attribute::Parse(is_par, ctx); + attr.Print(os); + } else if (test_type == ProgramTest) { + auto program = ir::Program::Parse(is_par, ctx); + program->Print(os); + } else if (test_type == TypeTest) { + auto type = ir::Type::Parse(is_par, ctx); + type.Print(os); + } + parser.reset(new ir::IrParser(ctx, os)); + std::vector after_parser_tokens; + while (parser->PeekToken().token_type_ != EOF_) { + auto str = parser->ConsumeToken().val_; + after_parser_tokens.push_back(str); + } + delete test_task; + if (after_parser_tokens.size() != before_parser_tokens.size()) { + return false; + } + + for (size_t i = 0; i < after_parser_tokens.size(); i++) { + if (after_parser_tokens[i] != before_parser_tokens[i]) { + return false; + } + } + + return true; +} + +TEST(IrParserTest, TestParserByFile) { + ir::IrContext* ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + std::ifstream is("TestParserText.txt"); + EXPECT_TRUE(is.is_open()); + ParserTest parser_test(is); + bool is_test = false; + while (TestTask* test_task = parser_test.GetTestTask()) { + is_test = true; + bool ans = parser_test.ConsumeTestTask(test_task, ctx); + EXPECT_TRUE(ans); + } + is.close(); + EXPECT_TRUE(is_test); +} diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc index c1c89dc9f78..0441860ed1d 
100644 --- a/test/cpp/ir/core/program_translator_test.cc +++ b/test/cpp/ir/core/program_translator_test.cc @@ -31,6 +31,8 @@ #include "paddle/ir/core/builtin_dialect.h" #include "paddle/ir/core/dialect.h" #include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/ir_parser.h" +#include "paddle/ir/core/ir_printer.h" #include "paddle/ir/core/program.h" using PaddleDialect = paddle::dialect::PaddleDialect; @@ -107,3 +109,37 @@ TEST(RegisterInfoTest, MainProgram) { EXPECT_EQ(unregistered_ops.size(), 1u); EXPECT_EQ(unregistered_ops[0], "something must not be registered"); } + +TEST(IrParserTest, MainProgram) { + auto p = load_from_file("resnet50_main.prog"); + EXPECT_EQ(p.Size(), 1u); + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + std::stringstream ss; + program->Print(ss); + std::unique_ptr parser_program = ir::Program::Parse(ss, ctx); + std::stringstream ssp; + parser_program->Print(ssp); + + EXPECT_TRUE(ssp.str() == ss.str()); +} + +TEST(IrParserTest, StartupProgram) { + auto p = load_from_file("resnet50_startup.prog"); + EXPECT_EQ(p.Size(), 1u); + ir::IrContext *ctx = ir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + auto program = paddle::TranslateLegacyProgramToProgram(p); + + std::stringstream ss; + program->Print(ss); + std::unique_ptr parser_program = ir::Program::Parse(ss, ctx); + std::stringstream ssp; + parser_program->Print(ssp); + + EXPECT_TRUE(ssp.str() == ss.str()); +} -- GitLab