From 3143d8bf9b361bcfbc932f8f447df191b8c9136a Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Thu, 25 May 2023 14:36:44 +0800
Subject: [PATCH] [IR] Refine Builder (#54052)

* refine code
* delete some unused code
* refine code of build
* refine code of build
* add block
* refine builder
* refine code
* refine code by comment
* fix compiler bug
---
 paddle/fluid/dialect/legacy_pd_op.h |  74 ++++++++++++++++-----------
 paddle/fluid/dialect/pd_attribute.h |   4 --
 paddle/fluid/dialect/pd_dialect.cc  |  20 ++++----
 paddle/fluid/dialect/pd_type.h      |   2 -
 paddle/ir/block.cc                  |  26 ++++++++++
 paddle/ir/block.h                   |  54 ++++++++++++++++++++
 paddle/ir/builder.cc                |  41 ++++++++++++++++
 paddle/ir/builder.h                 |  40 ++++++++++++----
 paddle/ir/builtin_attribute.h       |   7 ---
 paddle/ir/builtin_type.h            |   9 ----
 paddle/ir/op_base.h                 |  22 ++++----
 paddle/ir/operation.h               |   7 +--
 paddle/ir/printer.cc                |   6 ++-
 paddle/ir/program.cc                |   8 +---
 paddle/ir/program.h                 |   6 +--
 test/cpp/ir/ir_op_test.cc           |  71 ++++++++++++++-----------
 test/cpp/ir/ir_program_test.cc      |  10 ++--
 test/cpp/ir/ir_value_test.cc        |  16 +++---
 18 files changed, 288 insertions(+), 135 deletions(-)
 create mode 100644 paddle/ir/block.cc
 create mode 100644 paddle/ir/block.h
 create mode 100644 paddle/ir/builder.cc

diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h
index 6e64cad575a..21be24720dc 100644
--- a/paddle/fluid/dialect/legacy_pd_op.h
+++ b/paddle/fluid/dialect/legacy_pd_op.h
@@ -25,40 +25,60 @@ namespace dialect {
   class className : public ir::Op<className> {                       \
    public:                                                           \
     static const char *name() { return OPNAME(op_name); }            \
-    static const char **attributes_name;                             \
+    static constexpr const char **attributes_name = nullptr;         \
     static constexpr uint32_t attributes_num = 0;                    \
     static void verify(const std::vector<ir::OpResult> &inputs,      \
                        const std::vector<ir::Type> &outputs,         \
                        const ir::AttributeMap &attributes) {         \
       LOG(WARNING) << "This is a fake verify";                       \
     }                                                                \
-  };                                                                 \
-  const char **className::attributes_name = nullptr;
+  };
 
-REIGSTER_EMPTY_OP(conv2d, Conv2DOp);
-REIGSTER_EMPTY_OP(feed, FeedOp);
-REIGSTER_EMPTY_OP(batch_norm, BatchNormOp);
-REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_);
-REIGSTER_EMPTY_OP(elementwise_add, ElementwiseAddOp);
-REIGSTER_EMPTY_OP(pool2d, Pool2DOp);
-REIGSTER_EMPTY_OP(flatten_contiguous_range, FlattenContiguousRangeOp);
-REIGSTER_EMPTY_OP(matmul_v2, MatmulV2Op);
-REIGSTER_EMPTY_OP(reshape2, Reshape2Op);
-REIGSTER_EMPTY_OP(softmax_with_cross_entropy, SoftmaxWithCrossEntropyOp);
-REIGSTER_EMPTY_OP(reduce_mean, ReduceMeanOp);
-REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op);
-REIGSTER_EMPTY_OP(fill_constant, FillConstantOp);
-REIGSTER_EMPTY_OP(reduce_mean_grad, ReduceMeanGradOp);
-REIGSTER_EMPTY_OP(softmax_with_cross_entropy_grad,
-                  SoftmaxWithCrossEntropyGradOp);
-REIGSTER_EMPTY_OP(elementwise_add_grad, ElementwiseAddGradOp);
-REIGSTER_EMPTY_OP(matmul_v2_grad, MatmulV2GradOp);
-REIGSTER_EMPTY_OP(flatten_contiguous_range_grad, FlattenContiguousRangeGradOp);
-REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp);
-REIGSTER_EMPTY_OP(batch_norm_grad, BatchNormGradOp);
-REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp);
-REIGSTER_EMPTY_OP(sum, SumOp);
-REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op);
+// TODO(zhangbo): As operators are supplemented and defined, they are gradually
+// removed.
+REIGSTER_EMPTY_OP(conv2d, Conv2DOp);           // To be customized: conv2d
+REIGSTER_EMPTY_OP(feed, FeedOp);               // To be customized: feed
+REIGSTER_EMPTY_OP(batch_norm, BatchNormOp);    // To be customized: batch_norm
+REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_);  // To be customized: batch_norm_
+REIGSTER_EMPTY_OP(elementwise_add,
+                  ElementwiseAddOp);  // To be customized: add (elementwise_add)
+REIGSTER_EMPTY_OP(pool2d, Pool2DOp);  // To be customized: pool2d
+REIGSTER_EMPTY_OP(
+    flatten_contiguous_range,
+    FlattenContiguousRangeOp);  // flatten (flatten_contiguous_range)
+REIGSTER_EMPTY_OP(matmul_v2,
+                  MatmulV2Op);  // To be customized: matmul (matmul_v2)
+REIGSTER_EMPTY_OP(reshape2, Reshape2Op);  // To be customized: reshape
+REIGSTER_EMPTY_OP(softmax_with_cross_entropy,
+                  SoftmaxWithCrossEntropyOp);  // cross_entropy_with_softmax
+                                               // (softmax_with_cross_entropy)
+REIGSTER_EMPTY_OP(reduce_mean,
+                  ReduceMeanOp);  // To be customized: mean (reduce_mean)
+REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op);  // topk (top_k_v2)
+REIGSTER_EMPTY_OP(fill_constant,
+                  FillConstantOp);  // To be customized: full (fill_constant)
+REIGSTER_EMPTY_OP(reduce_mean_grad,
+                  ReduceMeanGradOp);  // To be customized: reduce_mean_grad
+REIGSTER_EMPTY_OP(
+    softmax_with_cross_entropy_grad,
+    SoftmaxWithCrossEntropyGradOp);  // cross_entropy_with_softmax_grad
+                                     // (softmax_with_cross_entropy_grad)
+REIGSTER_EMPTY_OP(
+    elementwise_add_grad,
+    ElementwiseAddGradOp);  // To be customized: add_grad (elementwise_add_grad)
+REIGSTER_EMPTY_OP(
+    matmul_v2_grad,
+    MatmulV2GradOp);  // To be customized: matmul_grad (matmul_v2_grad)
+REIGSTER_EMPTY_OP(
+    flatten_contiguous_range_grad,
+    FlattenContiguousRangeGradOp);  // flatten_grad
+                                    // (flatten_contiguous_range_grad)
+REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp);  // To be customized: pool2d_grad
+REIGSTER_EMPTY_OP(batch_norm_grad,
+                  BatchNormGradOp);  // To be customized: batch_norm_grad
+REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp);  // To be customized: conv2d_grad
+REIGSTER_EMPTY_OP(sum, SumOp);  // To be customized: sum (reduce_sum)
+REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op);  // To be customized: fetch_v2
 
 }  // namespace dialect
 }  // namespace paddle
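Each REIGSTER_EMPTY_OP line above stamps out a placeholder op class with no attributes and a fake verify. As a sketch of what the macro expands to (assuming OPNAME prefixes the op name with the dialect name, e.g. "pd.conv2d"), REIGSTER_EMPTY_OP(conv2d, Conv2DOp) yields roughly:

    class Conv2DOp : public ir::Op<Conv2DOp> {
     public:
      static const char *name() { return OPNAME(conv2d); }  // assumed "pd.conv2d"
      static constexpr const char **attributes_name = nullptr;
      static constexpr uint32_t attributes_num = 0;
      static void verify(const std::vector<ir::OpResult> &inputs,
                         const std::vector<ir::Type> &outputs,
                         const ir::AttributeMap &attributes) {
        LOG(WARNING) << "This is a fake verify";
      }
    };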
diff --git a/paddle/fluid/dialect/pd_attribute.h b/paddle/fluid/dialect/pd_attribute.h
index 75eed82dfc4..59481f1cedc 100644
--- a/paddle/fluid/dialect/pd_attribute.h
+++ b/paddle/fluid/dialect/pd_attribute.h
@@ -19,10 +19,6 @@
 namespace paddle {
 namespace dialect {
 
-#define GET_PD_DIALECT_ATTRIBUTE_LIST                                     \
-  IntArrayAttribute, ScalarAttribute, DataTypeAttribute, PlaceAttribute, \
-      DataLayoutAttribute
-
 class IntArrayAttribute : public ir::Attribute {
  public:
   using Attribute::Attribute;
diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc
index 14aa2080a6e..9baeb4c1f9b 100644
--- a/paddle/fluid/dialect/pd_dialect.cc
+++ b/paddle/fluid/dialect/pd_dialect.cc
@@ -29,12 +29,12 @@
 namespace paddle {
 namespace dialect {
 std::shared_ptr<paddle::framework::Variable>
-ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
+ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) {
   if (parameter->type().isa<DenseTensorType>()) {
     VLOG(4) << "Convert a DenseTensor Parameter to a variable.";
     std::shared_ptr<paddle::framework::Variable> var =
         std::make_shared<paddle::framework::Variable>();
-    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
+    phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
     // Init DenseTensor
     auto dim = parameter->type().dyn_cast<DenseTensorType>().dim();
     phi::DenseTensorMeta meta(
@@ -46,7 +46,7 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
         parameter->type().dyn_cast<DenseTensorType>().lod(),
         parameter->type().dyn_cast<DenseTensorType>().offset());
     tensor->set_meta(meta);
-    paddle::platform::DeviceContext* dev_ctx =
+    paddle::platform::DeviceContext *dev_ctx =
         paddle::platform::DeviceContextPool::Instance().Get(
             paddle::platform::CPUPlace());
     dev_ctx->Alloc(tensor,
@@ -62,11 +62,11 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
 }
 
 std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
-    paddle::framework::Variable* var) {
+    paddle::framework::Variable *var) {
   if (var->IsType<phi::DenseTensor>()) {
-    phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
+    phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
     // Get Meta
-    ir::IrContext* ctx = ir::IrContext::Instance();
+    ir::IrContext *ctx = ir::IrContext::Instance();
     ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx);
     DenseTensorTypeStorage::Dim dims(tensor->dims().size());
     std::copy(tensor->dims().Get(),
@@ -76,7 +76,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
         TransToIrDataLayout(tensor->layout());
     DenseTensorTypeStorage::LoD lod = tensor->lod();
     size_t offset = tensor->meta().offset;
-    void* data = tensor->data();
+    void *data = tensor->data();
     ir::Type dense_tensor_type =
         DenseTensorType::get(ctx, data_type, dims, data_layout, lod, offset);
     return std::make_unique<ir::Parameter>(
@@ -88,7 +88,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
 }
 
-PaddleDialect::PaddleDialect(ir::IrContext* context)
+PaddleDialect::PaddleDialect(ir::IrContext *context)
     : ir::Dialect(name(), context, ir::TypeId::get<PaddleDialect>()) {
   initialize();
 }
@@ -136,11 +136,11 @@ void PaddleDialect::initialize() {
       FetchV2Op>();
 }
 
-void PaddleDialect::PrintType(ir::Type type, std::ostream& os) {
+void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
   DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
 
   os << "tensor<";
-  auto& dims = tensor_type.dim();
+  auto &dims = tensor_type.dim();
   for (auto d : dims) {
     os << d;
     os << "x";
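A minimal sketch of the conversion round trip the interface above provides; `convert` and `param` are hypothetical handles (how the interface is obtained from the PaddleDialect is not shown in this patch):

    // Parameter -> Variable: allocates a CPU DenseTensor and copies the data.
    std::shared_ptr<paddle::framework::Variable> var =
        convert->ParameterToVariable(param);  // param holds a DenseTensorType
    // Variable -> Parameter: rebuilds a DenseTensorType from the tensor meta.
    std::unique_ptr<ir::Parameter> back = convert->VariableToParameter(var.get());
    // dtype, dims, layout, lod, and offset survive both directions.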
diff --git a/paddle/fluid/dialect/pd_type.h b/paddle/fluid/dialect/pd_type.h
index b644f12c1f8..e07951bbf93 100644
--- a/paddle/fluid/dialect/pd_type.h
+++ b/paddle/fluid/dialect/pd_type.h
@@ -19,8 +19,6 @@
 namespace paddle {
 namespace dialect {
 
-#define GET_PD_DIALECT_TYPE_LIST paddle::dialect::DenseTensorType
-
 ///
 /// \brief Define built-in parametric types.
 ///
diff --git a/paddle/ir/block.cc b/paddle/ir/block.cc
new file mode 100644
index 00000000000..9603eff00f1
--- /dev/null
+++ b/paddle/ir/block.cc
@@ -0,0 +1,26 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/ir/block.h"
+
+namespace ir {
+Block::~Block() { clear(); }
+
+void Block::clear() {
+  while (!empty()) {
+    ops_.back()->destroy();
+    ops_.pop_back();
+  }
+}
+}  // namespace ir
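A Block owns the operations placed into it: clear() pops from the back and calls Operation::destroy() on each op, and the destructor delegates to clear(). A minimal sketch of that ownership contract (assumes `op` came from ir::Operation::create):

    {
      ir::Block block;
      block.push_back(op);  // the block takes ownership of op
    }                       // ~Block() -> clear() -> op->destroy()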
diff --git a/paddle/ir/block.h b/paddle/ir/block.h
new file mode 100644
index 00000000000..3176bfe4c9e
--- /dev/null
+++ b/paddle/ir/block.h
@@ -0,0 +1,54 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <list>
+#include "paddle/ir/operation.h"
+
+namespace ir {
+class Block {
+ public:
+  using iterator = std::list<Operation *>::iterator;
+  using reverse_iterator = std::list<Operation *>::reverse_iterator;
+
+  Block() = default;
+  ~Block();
+
+  bool empty() const { return ops_.empty(); }
+  size_t size() const { return ops_.size(); }
+
+  iterator begin() { return ops_.begin(); }
+  iterator end() { return ops_.end(); }
+  reverse_iterator rbegin() { return ops_.rbegin(); }
+  reverse_iterator rend() { return ops_.rend(); }
+
+  Operation *back() { return ops_.back(); }
+  Operation *front() { return ops_.front(); }
+  void push_back(Operation *op) { ops_.push_back(op); }
+  void push_front(Operation *op) { ops_.push_front(op); }
+  std::list<Operation *>::iterator insert(
+      std::list<Operation *>::const_iterator iterator, Operation *op) {
+    return ops_.insert(iterator, op);
+  }
+  void clear();
+
+ private:
+  Block(Block &) = delete;
+  void operator=(Block &) = delete;
+
+ private:
+  std::list<Operation *> ops_;  // owned
+};
+}  // namespace ir
diff --git a/paddle/ir/builder.cc b/paddle/ir/builder.cc
new file mode 100644
index 00000000000..842e8c63bcf
--- /dev/null
+++ b/paddle/ir/builder.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/ir/builder.h"
+
+namespace ir {
+Operation *Builder::insert(Operation *op) {
+  if (block_) {
+    block_->insert(insert_point_, op);
+  } else {
+    LOG(WARNING) << "Builder's Block is nullptr, insert failed.";
+  }
+  return op;
+}
+
+/// Creates an operation given the fields represented as an OperationArgument.
+Operation *Builder::create(const OperationArgument &argument) {
+  return insert(Operation::create(argument));
+}
+
+/// Creates an operation with the given fields.
+Operation *Builder::create(const std::vector<ir::OpResult> &inputs,
+                           const std::vector<ir::Type> &output_types,
+                           const AttributeMap &attribute,
+                           ir::OpInfo op_info) {
+  OperationArgument argument(op_info, inputs, output_types, attribute);
+  return create(argument);
+}
+
+}  // namespace ir
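Together, Block and Builder replace the old op_list_ plumbing: the builder carries a block plus an insertion point, and every create() funnels through insert(). A usage sketch (op_info is assumed to come from ctx->GetRegisteredOpInfo(...); this is not a hunk from the patch):

    ir::IrContext *ctx = ir::IrContext::Instance();
    ir::Program program;
    ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
    // Creates a {} -> f32 op and inserts it before the block-end iterator.
    ir::Operation *op =
        builder.create({}, {ir::Float32Type::get(ctx)}, {}, op_info);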
diff --git a/paddle/ir/builder.h b/paddle/ir/builder.h
index cfca24cd5dd..c1762ca1441 100644
--- a/paddle/ir/builder.h
+++ b/paddle/ir/builder.h
@@ -16,7 +16,9 @@
 
 #include <list>
 
+#include "paddle/ir/block.h"
 #include "paddle/ir/operation.h"
+#include "paddle/ir/program.h"
 
 namespace ir {
 ///
@@ -25,25 +27,47 @@ namespace ir {
 ///
 class Builder {
  public:
-  explicit Builder(IrContext *context) : context_(context) {}
-  explicit Builder(Operation *op) : Builder(op->ir_context()) {}
+  explicit Builder(IrContext *context,
+                   Block *block,
+                   Block::iterator insert_point)
+      : context_(context), block_(block), insert_point_(insert_point) {}
+
+  static Builder AtBlockBegin(IrContext *context, Block *block) {
+    return Builder(context, block, block->begin());
+  }
+
+  static Builder AtBlockEnd(IrContext *context, Block *block) {
+    return Builder(context, block, block->end());
+  }
+
+  IrContext *context() const { return context_; }
+
+  Block *block() const { return block_; }
+
+  Operation *insert(Operation *op);
+
+  /// Creates an operation given the fields represented as an OperationArgument.
+  Operation *create(const OperationArgument &argument);
+
+  /// Creates an operation with the given fields.
+  Operation *create(const std::vector<ir::OpResult> &inputs,
+                    const std::vector<ir::Type> &output_types,
+                    const AttributeMap &attribute,
+                    ir::OpInfo op_info);
 
   /// Create an operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpTy create(Args &&...args) {
     OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
     OpTy::build(*this, argument, std::forward<Args>(args)...);
-    Operation *op = Operation::create(argument);
+    Operation *op = create(argument);
     return op->dyn_cast<OpTy>();
   }
 
  private:
   IrContext *context_;
-  // The current op list this builder is inserting into.
-  // After the design of the block data structure is completed,
-  // this member will be replaced by the block.
-  std::list<Operation *> *op_list_ = nullptr;
+  Block *block_ = nullptr;
   // The insertion point within the list that this builder is inserting before.
-  std::list<Operation *>::iterator insertPoint;
+  Block::iterator insert_point_;
 };
 }  // namespace ir
diff --git a/paddle/ir/builtin_attribute.h b/paddle/ir/builtin_attribute.h
index 93eb8599b20..468bf1be6ab 100644
--- a/paddle/ir/builtin_attribute.h
+++ b/paddle/ir/builtin_attribute.h
@@ -19,13 +19,6 @@
 #include "paddle/ir/utils.h"
 
 namespace ir {
-///
-/// \brief All built-in attributes.
-///
-#define GET_BUILT_IN_ATTRIBUTE_LIST                              \
-  StrAttribute, BoolAttribute, FloatAttribute, DoubleAttribute, \
-      Int32_tAttribute, Int64_tAttribute, ArrayAttribute
-
 class StrAttribute : public Attribute {
  public:
   using Attribute::Attribute;
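The templated create<OpTy> above inverts responsibility: OpTy::name() locates the registered OpInfo, OpTy::build() fills the OperationArgument, and the builder constructs and inserts the result. A hypothetical op sketching that contract (name and output type are illustrative; verify, attributes, and dialect registration are omitted):

    class MyOp : public ir::Op<MyOp> {
     public:
      static const char *name() { return "test.my_op"; }  // assumed registered
      static void build(const ir::Builder &builder,
                        ir::OperationArgument &argument) {  // NOLINT
        std::vector<ir::Type> output_types = {
            ir::Float32Type::get(builder.context())};
        argument.addTypes<std::vector<ir::Type>::iterator>(output_types.begin(),
                                                           output_types.end());
      }
    };

    // Given a Builder as above: one call runs build(), creates, and inserts.
    MyOp my_op = builder.create<MyOp>();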
diff --git a/paddle/ir/builtin_type.h b/paddle/ir/builtin_type.h
index 803638750cb..cb7dd8a853d 100644
--- a/paddle/ir/builtin_type.h
+++ b/paddle/ir/builtin_type.h
@@ -18,15 +18,6 @@
 #include "paddle/ir/type.h"
 
 namespace ir {
-///
-/// \brief This macro is used to get a list of all built-in types in this file.
-/// The built-in Dialect will use this macro to quickly register all built-in
-/// types.
-///
-#define GET_BUILT_IN_TYPE_LIST                                               \
-  BFloat16Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, \
-      Int32Type, Int64Type, BoolType, VectorType
-
 ///
 /// \brief Define built-in parameterless types. Please add the necessary
 /// interface functions for built-in types through the macro
diff --git a/paddle/ir/op_base.h b/paddle/ir/op_base.h
index b245cb7850f..47fa243b1c4 100644
--- a/paddle/ir/op_base.h
+++ b/paddle/ir/op_base.h
@@ -66,18 +66,18 @@ class InterfaceValue {
 
 class OpBase {
  public:
-  explicit OpBase(const Operation *operation) : operation_(operation) {}
+  explicit OpBase(Operation *operation) : operation_(operation) {}
 
-  const Operation *operation() const { return operation_; }
+  Operation *operation() const { return operation_; }
 
   explicit operator bool() const { return operation() != nullptr; }
 
-  operator const Operation *() const { return operation_; }
+  operator Operation *() const { return operation_; }
 
-  const Operation *operator->() const { return operation_; }
+  Operation *operator->() const { return operation_; }
 
  private:
-  const Operation *operation_;  // Not owned
+  Operation *operation_;  // Not owned
 };
 
 ///
@@ -86,11 +86,11 @@ class OpBase {
 template <class ConcreteTrait>
 class OpTraitBase : public OpBase {
  public:
-  explicit OpTraitBase(const Operation *op) : OpBase(op) {}
+  explicit OpTraitBase(Operation *op) : OpBase(op) {}
 
   static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
 
-  static ConcreteTrait dyn_cast(const Operation *op) {
+  static ConcreteTrait dyn_cast(Operation *op) {
     if (op->HasTrait<ConcreteTrait>()) {
       return ConcreteTrait(op);
     }
@@ -104,13 +104,11 @@ class OpTraitBase : public OpBase {
 template <class ConcreteInterface>
 class OpInterfaceBase : public OpBase {
  public:
-  // explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
-
-  explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
+  explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
 
   static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
 
-  static ConcreteInterface dyn_cast(const Operation *op) {
+  static ConcreteInterface dyn_cast(Operation *op) {
     if (op->HasInterface<ConcreteInterface>()) {
       return ConcreteInterface(
           op, op->op_info().GetInterfaceImpl<ConcreteInterface>());
@@ -183,7 +181,7 @@ class Op : public OpBase {
   using InterfaceList =
       typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
 
-  static ConcreteOp dyn_cast(const Operation *op) {
+  static ConcreteOp dyn_cast(Operation *op) {
     if (op->op_info().id() == TypeId::get<ConcreteOp>()) {
       return ConcreteOp(op);
     }
diff --git a/paddle/ir/operation.h b/paddle/ir/operation.h
index a62d248edb3..0b4f03c8e6a 100644
--- a/paddle/ir/operation.h
+++ b/paddle/ir/operation.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <vector>
 #include "paddle/ir/op_info.h"
 #include "paddle/ir/operation_utils.h"
 #include "paddle/ir/type.h"
@@ -61,7 +62,7 @@ class alignas(8) Operation final {
   std::string op_name() const;
 
   template <typename T>
-  T dyn_cast() const {
+  T dyn_cast() {
     return CastUtil<T>::call(this);
   }
 
@@ -89,7 +90,7 @@ class alignas(8) Operation final {
   template <typename T>
   struct CastUtil {
-    static T call(const Operation *op) {
+    static T call(Operation *op) {
       throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
     }
   };
 
@@ -98,7 +99,7 @@ class alignas(8) Operation final {
   struct CastUtil<
       T,
       typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
-    static T call(const Operation *op) { return T::dyn_cast(op); }
+    static T call(Operation *op) { return T::dyn_cast(op); }
   };
 
   AttributeMap attribute_;
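Dropping const from OpBase and the dyn_cast chain lets a cast result invoke behavior through the wrapped Operation*. With the signatures above, one pointer supports several views; ReadOnlyTrait, InferShapeInterface, and Operation2 are defined in test/cpp/ir/ir_op_test.cc below, and `op` is assumed to be created with Operation2's OpInfo as in that test:

    ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();              // trait view
    InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
    interface.InferShape();  // dispatches through the Concept/Model impl
    Operation2 concrete = op->dyn_cast<Operation2>();                 // op view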
diff --git a/paddle/ir/printer.cc b/paddle/ir/printer.cc
index 421dedabb3a..fbb1673b569 100644
--- a/paddle/ir/printer.cc
+++ b/paddle/ir/printer.cc
@@ -95,9 +95,11 @@ class ProgramPrinter : public Printer {
   explicit ProgramPrinter(std::ostream& os) : Printer(os), cur_var_number(0) {}
 
   void Print(ir::Program& program) {
-    for (auto* op : program.ops()) {
-      PrintOperation(op);
+    auto iterator = program.block()->begin();
+    while (iterator != program.block()->end()) {
+      PrintOperation(*iterator);
       os << newline;
+      iterator++;
     }
   }
 
diff --git a/paddle/ir/program.cc b/paddle/ir/program.cc
index 4caa7a80513..6b8524cee1c 100644
--- a/paddle/ir/program.cc
+++ b/paddle/ir/program.cc
@@ -16,14 +16,10 @@
 #include "paddle/ir/ir_context.h"
 
 namespace ir {
-Program::~Program() {
-  for (auto op : ops_) {
-    op->destroy();
-  }
-}
+Program::~Program() = default;
 
 void Program::InsertOp(Operation* op) {
-  ops_.push_back(op);
+  block_.push_back(op);
   op->set_parent_program(this);
 }
 
diff --git a/paddle/ir/program.h b/paddle/ir/program.h
index bcae617b2b9..5115034755e 100644
--- a/paddle/ir/program.h
+++ b/paddle/ir/program.h
@@ -17,6 +17,7 @@
 #include <list>
 #include <unordered_map>
 
+#include "paddle/ir/block.h"
 #include "paddle/ir/builtin_attribute.h"
 #include "paddle/ir/operation.h"
 #include "paddle/ir/parameter.h"
@@ -34,7 +35,7 @@ class Program {
  public:
   ~Program();
 
-  std::list<Operation*> ops() const { return ops_; }
+  Block* block() { return &block_; }
 
   size_t parameters_num() const { return parameters_.size(); }
 
@@ -51,8 +52,7 @@ class Program {
   void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter);
 
  private:
-  std::list<Operation*> ops_;  // owned
-
+  Block block_;
   std::unordered_map<std::string, std::unique_ptr<Parameter>> parameters_;
 };
 
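Program now stores its ops in a Block, which is why ~Program() can be defaulted (the Block destructor destroys the ops) and why external traversal goes through block(). The printer change above is exactly this pattern; the equivalent loop is:

    for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
      ir::Operation *op = *it;  // Block::iterator yields Operation*
      // ... visit op ...
    }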
diff --git a/test/cpp/ir/ir_op_test.cc b/test/cpp/ir/ir_op_test.cc
index aa580b0b883..69ea38c487c 100644
--- a/test/cpp/ir/ir_op_test.cc
+++ b/test/cpp/ir/ir_op_test.cc
@@ -14,6 +14,7 @@
 
 #include <gtest/gtest.h>
 
+#include "paddle/ir/builder.h"
 #include "paddle/ir/builtin_attribute.h"
 #include "paddle/ir/builtin_type.h"
 #include "paddle/ir/dialect.h"
@@ -23,7 +24,7 @@
 /// \brief Define built-in Trait, derived from OpTraitBase.
 class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
  public:
-  explicit ReadOnlyTrait(const ir::Operation *op)
+  explicit ReadOnlyTrait(ir::Operation *op)
       : ir::OpTraitBase<ReadOnlyTrait>(op) {}
 };
 
@@ -34,14 +35,14 @@ class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
 class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
  public:
   struct Concept {
-    explicit Concept(void (*infer_shape)(const ir::Operation *))
+    explicit Concept(void (*infer_shape)(ir::Operation *))
         : infer_shape_(infer_shape) {}
-    void (*infer_shape_)(const ir::Operation *);
+    void (*infer_shape_)(ir::Operation *);
   };
 
   template <class ConcreteOp>
   struct Model : public Concept {
-    static void InferShape(const ir::Operation *op) {
+    static void InferShape(ir::Operation *op) {
       ConcreteOp concret_op = ConcreteOp(op);
       if (concret_op == nullptr) throw("concret_op is nullptr");
       concret_op.InferShape();
@@ -53,7 +54,7 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
     }
   };
 
-  InferShapeInterface(const ir::Operation *op, Concept *impl)
+  InferShapeInterface(ir::Operation *op, Concept *impl)
       : ir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {}
 
   void InferShape() { impl_->infer_shape_(operation()); }
@@ -62,6 +63,18 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
   Concept *impl_;
 };
 
+ir::AttributeMap CreateAttributeMap(std::vector<std::string> attribute_names,
+                                    std::vector<std::string> attributes) {
+  ir::IrContext *ctx = ir::IrContext::Instance();
+  ir::AttributeMap attr_map;
+  for (size_t i = 0; i < attribute_names.size(); i++) {
+    ir::Attribute attr_value = ir::StrAttribute::get(ctx, attributes[i]);
+    attr_map.insert(
+        std::pair<std::string, ir::Attribute>(attribute_names[i], attr_value));
+  }
+  return attr_map;
+}
+
 // Define op1.
 class Operation1 : public ir::Op<Operation1> {
  public:
@@ -81,6 +94,22 @@ class Operation1 : public ir::Op<Operation1> {
       throw("Type of attribute: parameter_name is not right.");
     }
   }
+  static void build(const ir::Builder &builder,
+                    ir::OperationArgument &argument) {  // NOLINT
+    std::vector<ir::OpResult> inputs = {};
+    std::vector<ir::Type> output_types = {
+        ir::Float32Type::get(builder.context())};
+    std::unordered_map<std::string, ir::Attribute> attributes =
+        CreateAttributeMap({"op1_attr1", "op1_attr2"},
+                           {"op1_attr1", "op1_attr2"});
+    argument.addOperands<std::vector<ir::OpResult>::iterator>(inputs.begin(),
+                                                              inputs.end());
+    argument.addTypes<std::vector<ir::Type>::iterator>(output_types.begin(),
+                                                       output_types.end());
+    argument.addAttributes<
+        std::unordered_map<std::string, ir::Attribute>::iterator>(
+        attributes.begin(), attributes.end());
+  }
 };
 const char *Operation1::attributes_name[attributes_num] = {"op1_attr1",
                                                            "op1_attr2"};
@@ -105,9 +134,7 @@ class Operation2
       throw("Type of attribute: parameter_name is not right.");
     }
   }
-  static void InferShape() {
-    std::cout << "This is op2's InferShape interface." << std::endl;
-  }
+  static void InferShape() { VLOG(0) << "This is op2's InferShape interface."; }
 };
 const char *Operation2::attributes_name[attributes_num] = {"op2_attr1",
                                                            "op2_attr2"};
@@ -125,23 +152,11 @@ class TestDialect : public ir::Dialect {
   void initialize() { RegisterOps<Operation1, Operation2>(); }
 };
 
-ir::AttributeMap CreateAttributeMap(std::vector<std::string> attribute_names,
-                                    std::vector<std::string> attributes) {
-  ir::IrContext *ctx = ir::IrContext::Instance();
-  ir::AttributeMap attr_map;
-  for (size_t i = 0; i < attribute_names.size(); i++) {
-    ir::Attribute attr_value = ir::StrAttribute::get(ctx, attributes[i]);
-    attr_map.insert(
-        std::pair<std::string, ir::Attribute>(attribute_names[i], attr_value));
-  }
-  return attr_map;
-}
-
 TEST(op_test, op_test) {
   // (1) Register Dialect, Operation1, Operation2 into IrContext.
   ir::IrContext *ctx = ir::IrContext::Instance();
   ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
-  std::cout << test_dialect << std::endl;
+  EXPECT_EQ(test_dialect != nullptr, true);
 
   // (2) Get registered operations.
   std::string op1_name = Operation1::name();
@@ -158,18 +173,18 @@ TEST(op_test, op_test) {
   // (3) Test uses for op.
   std::vector<ir::OpResult> op_inputs = {};
   std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
-  ir::Operation *op =
+  ir::Operation *op2 =
       ir::Operation::create(op_inputs,
                             op_output_types,
                             CreateAttributeMap({"op2_attr1", "op2_attr2"},
                                                {"op2_attr1", "op2_attr2"}),
                             op2_info);
 
-  ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();
-  EXPECT_EQ(trait.operation(), op);
-  InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
+  ReadOnlyTrait trait = op2->dyn_cast<ReadOnlyTrait>();
+  EXPECT_EQ(trait.operation(), op2);
+  InferShapeInterface interface = op2->dyn_cast<InferShapeInterface>();
   interface.InferShape();
-  Operation2 Op2 = op->dyn_cast<Operation2>();
-  EXPECT_EQ(Op2.operation(), op);
-  op->destroy();
+  Operation2 Op2 = op2->dyn_cast<Operation2>();
+  EXPECT_EQ(Op2.operation(), op2);
+  op2->destroy();
 }
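Operation1::build above is what the one-line templated creation consumes. A hedged sketch of driving it end to end (it mirrors the APIs added in this patch but is not a hunk from the test file):

    ir::IrContext *ctx = ir::IrContext::Instance();
    ctx->GetOrRegisterDialect<TestDialect>();
    ir::Program program;
    ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
    Operation1 op1 = builder.create<Operation1>();  // build() + insert at end
    EXPECT_EQ(program.block()->size() == 1, true);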
diff --git a/test/cpp/ir/ir_program_test.cc b/test/cpp/ir/ir_program_test.cc
index 7c6c9acaf52..9fb72fec13c 100644
--- a/test/cpp/ir/ir_program_test.cc
+++ b/test/cpp/ir/ir_program_test.cc
@@ -57,7 +57,7 @@ TEST(program_test, program) {
   // (2) Create an empty program object
   ir::Program program;
   //   ir::Program *program = new ir::Program();
-  EXPECT_EQ(program.ops().size() == 0, true);
+  EXPECT_EQ(program.block()->size() == 0, true);
 
   // (3) Create a float32 DenseTensor Parameter and save into Program
   ir::Type fp32_dtype = ir::Float32Type::get(ctx);
@@ -207,8 +207,7 @@ TEST(program_test, program) {
   program.SetParameter("c", std::move(parameter_c));
 
   // (8) Traverse Program
-  std::list<ir::Operation *> ops = program.ops();
-  EXPECT_EQ(ops.size() == 4, true);
+  EXPECT_EQ(program.block()->size() == 4, true);
   EXPECT_EQ(program.parameters_num() == 3, true);
 }
 
@@ -220,7 +219,7 @@ TEST(program_test, slice_combine_test) {
   // (2) Create an empty program object
   ir::Program program;
   //   ir::Program *program = new ir::Program();
-  EXPECT_EQ(program.ops().size() == 0, true);
+  EXPECT_EQ(program.block()->size() == 0, true);
 
   // (3) Create a float32 DenseTensor Parameter and save into Program
   ir::Type fp32_dtype = ir::Float32Type::get(ctx);
@@ -267,6 +266,5 @@ TEST(program_test, slice_combine_test) {
   program.InsertOp(slice_op);
 
   // (8) Traverse Program
-  std::list<ir::Operation *> ops = program.ops();
-  EXPECT_EQ(ops.size() == 4, true);
+  EXPECT_EQ(program.block()->size() == 4, true);
 }
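Parameters remain bookkeeping on the Program itself, separate from the op Block, so block()->size() and parameters_num() count different things. A minimal sketch (parameter construction elided; `parameter_a` is assumed to be a std::unique_ptr<ir::Parameter> built as in the test above):

    program.SetParameter("a", std::move(parameter_a));
    EXPECT_EQ(program.parameters_num() == 1, true);
    EXPECT_EQ(program.block()->size() == 0, true);  // no ops were inserted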
diff --git a/test/cpp/ir/ir_value_test.cc b/test/cpp/ir/ir_value_test.cc
index e0f0d83312a..00a35db2cab 100644
--- a/test/cpp/ir/ir_value_test.cc
+++ b/test/cpp/ir/ir_value_test.cc
@@ -43,7 +43,7 @@ TEST(value_test, value_test) {
                             op1_output_types,
                             CreateAttributeMap("op1_name", "op1_attr"),
                             nullptr);
-  std::cout << op1->print() << std::endl;
+  VLOG(0) << op1->print();
   // 2. Construct OP2: b = OP2();
   std::vector<ir::OpResult> op2_inputs = {};
   std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
@@ -52,7 +52,7 @@ TEST(value_test, value_test) {
                             op2_output_types,
                             CreateAttributeMap("op2_name", "op2_attr"),
                             nullptr);
-  std::cout << op2->print() << std::endl;
+  VLOG(0) << op2->print() << std::endl;
   // 3. Construct OP3: c = OP3(a, b);
   std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
                                           op2->GetResultByIndex(0)};
   std::vector<ir::Type> op3_output_types = {ir::Float32Type::get(ctx)};
@@ -62,7 +62,7 @@ TEST(value_test, value_test) {
                             op3_output_types,
                             CreateAttributeMap("op3_name", "op3_attr"),
                             nullptr);
-  std::cout << op3->print() << std::endl;
+  VLOG(0) << op3->print() << std::endl;
   // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
   std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
                                           op3->GetResultByIndex(0)};
   std::vector<ir::Type> op4_output_types;
@@ -75,7 +75,7 @@ TEST(value_test, value_test) {
                             op4_output_types,
                             CreateAttributeMap("op4_name", "op4_attr"),
                             nullptr);
-  std::cout << op4->print() << std::endl;
+  VLOG(0) << op4->print() << std::endl;
 
   // Test 1:
   EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1);
@@ -103,12 +103,12 @@ TEST(value_test, value_test) {
   EXPECT_EQ(iter.owner(), op3);
 
   // destroy
-  std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
+  VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
   op4->destroy();
-  std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
+  VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
   op3->destroy();
-  std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
+  VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
   op2->destroy();
-  std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
+  VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
   op1->destroy();
 }
-- 
GitLab