未验证 提交 3143d8bf 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Refine Builder (#54052)

* refine code

* delete some unused code

* refine code of build

* refine code of build

* add block

* refine builder

* refine code

* refine code by comment

* fix compiler bug
上级 1549fbd3
...@@ -25,40 +25,60 @@ namespace dialect { ...@@ -25,40 +25,60 @@ namespace dialect {
class className : public ir::Op<className> { \ class className : public ir::Op<className> { \
public: \ public: \
static const char *name() { return OPNAME(op_name); } \ static const char *name() { return OPNAME(op_name); } \
static const char **attributes_name; \ static constexpr const char **attributes_name = nullptr; \
static constexpr uint32_t attributes_num = 0; \ static constexpr uint32_t attributes_num = 0; \
static void verify(const std::vector<ir::OpResult> &inputs, \ static void verify(const std::vector<ir::OpResult> &inputs, \
const std::vector<ir::Type> &outputs, \ const std::vector<ir::Type> &outputs, \
const ir::AttributeMap &attributes) { \ const ir::AttributeMap &attributes) { \
LOG(WARNING) << "This is a fake verify"; \ LOG(WARNING) << "This is a fake verify"; \
} \ } \
}; \ };
const char **className::attributes_name = nullptr;
REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // TODO(zhangbo): As operators are supplemented and defined, they are gradually
REIGSTER_EMPTY_OP(feed, FeedOp); // removed.
REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d
REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed
REIGSTER_EMPTY_OP(elementwise_add, ElementwiseAddOp); REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: batch_norm
REIGSTER_EMPTY_OP(pool2d, Pool2DOp); REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_
REIGSTER_EMPTY_OP(flatten_contiguous_range, FlattenContiguousRangeOp); REIGSTER_EMPTY_OP(elementwise_add,
REIGSTER_EMPTY_OP(matmul_v2, MatmulV2Op); ElementwiseAddOp); // To be customized: add (elementwise_add)
REIGSTER_EMPTY_OP(reshape2, Reshape2Op); REIGSTER_EMPTY_OP(pool2d, Pool2DOp); // To be customized: pool2d
REIGSTER_EMPTY_OP(softmax_with_cross_entropy, SoftmaxWithCrossEntropyOp); REIGSTER_EMPTY_OP(
REIGSTER_EMPTY_OP(reduce_mean, ReduceMeanOp); flatten_contiguous_range,
REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); FlattenContiguousRangeOp); // flatten (flatten_contiguous_range)
REIGSTER_EMPTY_OP(fill_constant, FillConstantOp); REIGSTER_EMPTY_OP(matmul_v2,
REIGSTER_EMPTY_OP(reduce_mean_grad, ReduceMeanGradOp); MatmulV2Op); // To be customized: matmul (matmul_v2)
REIGSTER_EMPTY_OP(softmax_with_cross_entropy_grad, REIGSTER_EMPTY_OP(reshape2, Reshape2Op); // To be customized: reshape
SoftmaxWithCrossEntropyGradOp); REIGSTER_EMPTY_OP(softmax_with_cross_entropy,
REIGSTER_EMPTY_OP(elementwise_add_grad, ElementwiseAddGradOp); SoftmaxWithCrossEntropyOp); // cross_entropy_with_softmax
REIGSTER_EMPTY_OP(matmul_v2_grad, MatmulV2GradOp); // (softmax_with_cross_entropy)
REIGSTER_EMPTY_OP(flatten_contiguous_range_grad, FlattenContiguousRangeGradOp); REIGSTER_EMPTY_OP(reduce_mean,
REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); ReduceMeanOp); // To be customized: mean (reduce_mean)
REIGSTER_EMPTY_OP(batch_norm_grad, BatchNormGradOp); REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); // topk (top_k_v2)
REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); REIGSTER_EMPTY_OP(fill_constant,
REIGSTER_EMPTY_OP(sum, SumOp); FillConstantOp); // To be customized: full (fill_constant)
REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); REIGSTER_EMPTY_OP(reduce_mean_grad,
ReduceMeanGradOp); // To be customized: reduce_mean_grad
REIGSTER_EMPTY_OP(
softmax_with_cross_entropy_grad,
SoftmaxWithCrossEntropyGradOp); // cross_entropy_with_softmax_grad
// (softmax_with_cross_entropy_grad)
REIGSTER_EMPTY_OP(
elementwise_add_grad,
ElementwiseAddGradOp); // To be customized: add_grad (elementwise_add_grad)
REIGSTER_EMPTY_OP(
matmul_v2_grad,
MatmulV2GradOp); // To be customized: matmul_grad (matmul_v2_grad)
REIGSTER_EMPTY_OP(
flatten_contiguous_range_grad,
FlattenContiguousRangeGradOp); // flatten_grad
// (flatten_contiguous_range_grad)
REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad
REIGSTER_EMPTY_OP(batch_norm_grad,
BatchNormGradOp); // To be customized: batch_norm_grad
REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad
REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum)
REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
...@@ -19,10 +19,6 @@ ...@@ -19,10 +19,6 @@
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
#define GET_PD_DIALECT_ATTRIBUTE_LIST \
IntArrayAttribute, ScalarAttribute, DataTypeAttribute, PlaceAttribute, \
DataLayoutAttribute
class IntArrayAttribute : public ir::Attribute { class IntArrayAttribute : public ir::Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
......
...@@ -29,12 +29,12 @@ ...@@ -29,12 +29,12 @@
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
std::shared_ptr<paddle::framework::Variable> std::shared_ptr<paddle::framework::Variable>
ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) { ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) {
if (parameter->type().isa<DenseTensorType>()) { if (parameter->type().isa<DenseTensorType>()) {
VLOG(4) << "Convert a DenseTensor Parameter to a variable."; VLOG(4) << "Convert a DenseTensor Parameter to a variable.";
std::shared_ptr<paddle::framework::Variable> var = std::shared_ptr<paddle::framework::Variable> var =
std::make_shared<paddle::framework::Variable>(); std::make_shared<paddle::framework::Variable>();
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>(); phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
// Init DenseTensor // Init DenseTensor
auto dim = parameter->type().dyn_cast<DenseTensorType>().dim(); auto dim = parameter->type().dyn_cast<DenseTensorType>().dim();
phi::DenseTensorMeta meta( phi::DenseTensorMeta meta(
...@@ -46,7 +46,7 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) { ...@@ -46,7 +46,7 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
parameter->type().dyn_cast<DenseTensorType>().lod(), parameter->type().dyn_cast<DenseTensorType>().lod(),
parameter->type().dyn_cast<DenseTensorType>().offset()); parameter->type().dyn_cast<DenseTensorType>().offset());
tensor->set_meta(meta); tensor->set_meta(meta);
paddle::platform::DeviceContext* dev_ctx = paddle::platform::DeviceContext *dev_ctx =
paddle::platform::DeviceContextPool::Instance().Get( paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace()); paddle::platform::CPUPlace());
dev_ctx->Alloc(tensor, dev_ctx->Alloc(tensor,
...@@ -62,11 +62,11 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) { ...@@ -62,11 +62,11 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
} }
std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter( std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
paddle::framework::Variable* var) { paddle::framework::Variable *var) {
if (var->IsType<phi::DenseTensor>()) { if (var->IsType<phi::DenseTensor>()) {
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>(); phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
// Get Meta // Get Meta
ir::IrContext* ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx); ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx);
DenseTensorTypeStorage::Dim dims(tensor->dims().size()); DenseTensorTypeStorage::Dim dims(tensor->dims().size());
std::copy(tensor->dims().Get(), std::copy(tensor->dims().Get(),
...@@ -76,7 +76,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter( ...@@ -76,7 +76,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
TransToIrDataLayout(tensor->layout()); TransToIrDataLayout(tensor->layout());
DenseTensorTypeStorage::LoD lod = tensor->lod(); DenseTensorTypeStorage::LoD lod = tensor->lod();
size_t offset = tensor->meta().offset; size_t offset = tensor->meta().offset;
void* data = tensor->data(); void *data = tensor->data();
ir::Type dense_tensor_type = ir::Type dense_tensor_type =
DenseTensorType::get(ctx, data_type, dims, data_layout, lod, offset); DenseTensorType::get(ctx, data_type, dims, data_layout, lod, offset);
return std::make_unique<ir::Parameter>( return std::make_unique<ir::Parameter>(
...@@ -88,7 +88,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter( ...@@ -88,7 +88,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
} }
} }
PaddleDialect::PaddleDialect(ir::IrContext* context) PaddleDialect::PaddleDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<PaddleDialect>()) { : ir::Dialect(name(), context, ir::TypeId::get<PaddleDialect>()) {
initialize(); initialize();
} }
...@@ -136,11 +136,11 @@ void PaddleDialect::initialize() { ...@@ -136,11 +136,11 @@ void PaddleDialect::initialize() {
FetchV2Op>(); FetchV2Op>();
} }
void PaddleDialect::PrintType(ir::Type type, std::ostream& os) { void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>(); DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
os << "tensor<"; os << "tensor<";
auto& dims = tensor_type.dim(); auto &dims = tensor_type.dim();
for (auto d : dims) { for (auto d : dims) {
os << d; os << d;
os << "x"; os << "x";
......
...@@ -19,8 +19,6 @@ ...@@ -19,8 +19,6 @@
namespace paddle { namespace paddle {
namespace dialect { namespace dialect {
#define GET_PD_DIALECT_TYPE_LIST paddle::dialect::DenseTensorType
/// ///
/// \brief Define built-in parametric types. /// \brief Define built-in parametric types.
/// ///
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/block.h"
namespace ir {
Block::~Block() { clear(); }
void Block::clear() {
while (!empty()) {
ops_.back()->destroy();
ops_.pop_back();
}
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include "paddle/ir/operation.h"
namespace ir {
class Block {
public:
using iterator = std::list<Operation *>::iterator;
using reverse_iterator = std::list<Operation *>::reverse_iterator;
Block() = default;
~Block();
bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); }
iterator begin() { return ops_.begin(); }
iterator end() { return ops_.end(); }
reverse_iterator rbegin() { return ops_.rbegin(); }
reverse_iterator rend() { return ops_.rend(); }
Operation *back() { return ops_.back(); }
Operation *front() { return ops_.front(); }
void push_back(Operation *op) { ops_.push_back(op); }
void push_front(Operation *op) { ops_.push_front(op); }
std::list<Operation *>::iterator insert(
std::list<Operation *>::const_iterator iterator, Operation *op) {
return ops_.insert(iterator, op);
}
void clear();
private:
Block(Block &) = delete;
void operator=(Block &) = delete;
private:
std::list<Operation *> ops_; // owned
};
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builder.h"
namespace ir {
Operation *Builder::insert(Operation *op) {
if (block_) {
block_->insert(insert_point_, op);
} else {
LOG(WARNING) << "Builder's Block is nullptr, insert failed.";
}
return op;
}
/// Create an operation given the fields represented as an OperationState.
Operation *Builder::create(const OperationArgument &argument) {
return insert(Operation::create(argument));
}
/// Creates an operation with the given fields.
Operation *Builder::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
ir::OpInfo op_info) {
OperationArgument argument(op_info, inputs, output_types, attribute);
return create(argument);
}
} // namespace ir
...@@ -16,7 +16,9 @@ ...@@ -16,7 +16,9 @@
#include <list> #include <list>
#include "paddle/ir/block.h"
#include "paddle/ir/operation.h" #include "paddle/ir/operation.h"
#include "paddle/ir/program.h"
namespace ir { namespace ir {
/// ///
...@@ -25,25 +27,47 @@ namespace ir { ...@@ -25,25 +27,47 @@ namespace ir {
/// ///
class Builder { class Builder {
public: public:
explicit Builder(IrContext *context) : context_(context) {} explicit Builder(IrContext *context,
explicit Builder(Operation *op) : Builder(op->ir_context()) {} Block *block,
Block::iterator insert_point)
: context_(context), block_(block), insert_point_(insert_point) {}
static Builder AtBlockBegin(IrContext *context, Block *block) {
return Builder(context, block, block->begin());
}
static Builder AtBlockEnd(IrContext *context, Block *block) {
return Builder(context, block, block->end());
}
IrContext *context() const { return context_; }
Block *block() const { return block_; }
Operation *insert(Operation *op);
/// Creates an operation given the fields represented as an OperationState.
Operation *create(const OperationArgument &argument);
/// Creates an operation with the given fields.
Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
ir::OpInfo op_info);
/// Create an operation of specific op type at the current insertion point. /// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args> template <typename OpTy, typename... Args>
OpTy create(Args &&...args) { OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name())); OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...); OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = Operation::create(argument); Operation *op = create(argument);
return op->dyn_cast<OpTy>(); return op->dyn_cast<OpTy>();
} }
private: private:
IrContext *context_; IrContext *context_;
// The current op list this builder is inserting into. Block *block_ = nullptr;
// After the design of the block data structure is completed,
// this member will be replaced by the block.
std::list<Operation *> *op_list_ = nullptr;
// The insertion point within the list that this builder is inserting before. // The insertion point within the list that this builder is inserting before.
std::list<Operation *>::iterator insertPoint; Block::iterator insert_point_;
}; };
} // namespace ir } // namespace ir
...@@ -19,13 +19,6 @@ ...@@ -19,13 +19,6 @@
#include "paddle/ir/utils.h" #include "paddle/ir/utils.h"
namespace ir { namespace ir {
///
/// \brief All built-in attributes.
///
#define GET_BUILT_IN_ATTRIBUTE_LIST \
StrAttribute, BoolAttribute, FloatAttribute, DoubleAttribute, \
Int32_tAttribute, Int64_tAttribute, ArrayAttribute
class StrAttribute : public Attribute { class StrAttribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
......
...@@ -18,15 +18,6 @@ ...@@ -18,15 +18,6 @@
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
namespace ir { namespace ir {
///
/// \brief This macro is used to get a list of all built-in types in this file.
/// The built-in Dialect will use this macro to quickly register all built-in
/// types.
///
#define GET_BUILT_IN_TYPE_LIST \
BFloat16Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, \
Int32Type, Int64Type, BoolType, VectorType
/// ///
/// \brief Define built-in parameterless types. Please add the necessary /// \brief Define built-in parameterless types. Please add the necessary
/// interface functions for built-in types through the macro /// interface functions for built-in types through the macro
......
...@@ -66,18 +66,18 @@ class InterfaceValue { ...@@ -66,18 +66,18 @@ class InterfaceValue {
class OpBase { class OpBase {
public: public:
explicit OpBase(const Operation *operation) : operation_(operation) {} explicit OpBase(Operation *operation) : operation_(operation) {}
const Operation *operation() const { return operation_; } Operation *operation() const { return operation_; }
explicit operator bool() const { return operation() != nullptr; } explicit operator bool() const { return operation() != nullptr; }
operator const Operation *() const { return operation_; } operator Operation *() const { return operation_; }
const Operation *operator->() const { return operation_; } Operation *operator->() const { return operation_; }
private: private:
const Operation *operation_; // Not owned Operation *operation_; // Not owned
}; };
/// ///
...@@ -86,11 +86,11 @@ class OpBase { ...@@ -86,11 +86,11 @@ class OpBase {
template <class ConcreteTrait> template <class ConcreteTrait>
class OpTraitBase : public OpBase { class OpTraitBase : public OpBase {
public: public:
explicit OpTraitBase(const Operation *op) : OpBase(op) {} explicit OpTraitBase(Operation *op) : OpBase(op) {}
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); } static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
static ConcreteTrait dyn_cast(const Operation *op) { static ConcreteTrait dyn_cast(Operation *op) {
if (op->HasTrait<ConcreteTrait>()) { if (op->HasTrait<ConcreteTrait>()) {
return ConcreteTrait(op); return ConcreteTrait(op);
} }
...@@ -104,13 +104,11 @@ class OpTraitBase : public OpBase { ...@@ -104,13 +104,11 @@ class OpTraitBase : public OpBase {
template <typename ConcreteInterface> template <typename ConcreteInterface>
class OpInterfaceBase : public OpBase { class OpInterfaceBase : public OpBase {
public: public:
// explicit OpInterfaceBase(Operation *op) : OpBase(op) {} explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); } static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
static ConcreteInterface dyn_cast(const Operation *op) { static ConcreteInterface dyn_cast(Operation *op) {
if (op->HasInterface<ConcreteInterface>()) { if (op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface( return ConcreteInterface(
op, op->op_info().GetInterfaceImpl<ConcreteInterface>()); op, op->op_info().GetInterfaceImpl<ConcreteInterface>());
...@@ -183,7 +181,7 @@ class Op : public OpBase { ...@@ -183,7 +181,7 @@ class Op : public OpBase {
using InterfaceList = using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type; typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
static ConcreteOp dyn_cast(const Operation *op) { static ConcreteOp dyn_cast(Operation *op) {
if (op->op_info().id() == TypeId::get<ConcreteOp>()) { if (op->op_info().id() == TypeId::get<ConcreteOp>()) {
return ConcreteOp(op); return ConcreteOp(op);
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <iostream>
#include "paddle/ir/op_info.h" #include "paddle/ir/op_info.h"
#include "paddle/ir/operation_utils.h" #include "paddle/ir/operation_utils.h"
#include "paddle/ir/type.h" #include "paddle/ir/type.h"
...@@ -61,7 +62,7 @@ class alignas(8) Operation final { ...@@ -61,7 +62,7 @@ class alignas(8) Operation final {
std::string op_name() const; std::string op_name() const;
template <typename T> template <typename T>
T dyn_cast() const { T dyn_cast() {
return CastUtil<T>::call(this); return CastUtil<T>::call(this);
} }
...@@ -89,7 +90,7 @@ class alignas(8) Operation final { ...@@ -89,7 +90,7 @@ class alignas(8) Operation final {
template <typename T, typename Enabler = void> template <typename T, typename Enabler = void>
struct CastUtil { struct CastUtil {
static T call(const Operation *op) { static T call(Operation *op) {
throw("Can't dyn_cast to T, T should be a Op or Trait or Interface"); throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
} }
}; };
...@@ -98,7 +99,7 @@ class alignas(8) Operation final { ...@@ -98,7 +99,7 @@ class alignas(8) Operation final {
struct CastUtil< struct CastUtil<
T, T,
typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> { typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
static T call(const Operation *op) { return T::dyn_cast(op); } static T call(Operation *op) { return T::dyn_cast(op); }
}; };
AttributeMap attribute_; AttributeMap attribute_;
......
...@@ -95,9 +95,11 @@ class ProgramPrinter : public Printer { ...@@ -95,9 +95,11 @@ class ProgramPrinter : public Printer {
explicit ProgramPrinter(std::ostream& os) : Printer(os), cur_var_number(0) {} explicit ProgramPrinter(std::ostream& os) : Printer(os), cur_var_number(0) {}
void Print(ir::Program& program) { void Print(ir::Program& program) {
for (auto* op : program.ops()) { auto iterator = program.block()->begin();
PrintOperation(op); while (iterator != program.block()->end()) {
PrintOperation(*iterator);
os << newline; os << newline;
iterator++;
} }
} }
......
...@@ -16,14 +16,10 @@ ...@@ -16,14 +16,10 @@
#include "paddle/ir/ir_context.h" #include "paddle/ir/ir_context.h"
namespace ir { namespace ir {
Program::~Program() { Program::~Program() = default;
for (auto op : ops_) {
op->destroy();
}
}
void Program::InsertOp(Operation* op) { void Program::InsertOp(Operation* op) {
ops_.push_back(op); block_.push_back(op);
op->set_parent_program(this); op->set_parent_program(this);
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <list> #include <list>
#include <unordered_map> #include <unordered_map>
#include "paddle/ir/block.h"
#include "paddle/ir/builtin_attribute.h" #include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/operation.h" #include "paddle/ir/operation.h"
#include "paddle/ir/parameter.h" #include "paddle/ir/parameter.h"
...@@ -34,7 +35,7 @@ class Program { ...@@ -34,7 +35,7 @@ class Program {
public: public:
~Program(); ~Program();
std::list<Operation*> ops() const { return ops_; } Block* block() { return &block_; }
size_t parameters_num() const { return parameters_.size(); } size_t parameters_num() const { return parameters_.size(); }
...@@ -51,8 +52,7 @@ class Program { ...@@ -51,8 +52,7 @@ class Program {
void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter); void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter);
private: private:
std::list<Operation*> ops_; // owned Block block_;
std::unordered_map<std::string, std::unique_ptr<Parameter>> parameters_; std::unordered_map<std::string, std::unique_ptr<Parameter>> parameters_;
}; };
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/ir/builder.h"
#include "paddle/ir/builtin_attribute.h" #include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/builtin_type.h" #include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h" #include "paddle/ir/dialect.h"
...@@ -23,7 +24,7 @@ ...@@ -23,7 +24,7 @@
/// \brief Define built-in Trait, derived from OpTraitBase. /// \brief Define built-in Trait, derived from OpTraitBase.
class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> { class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
public: public:
explicit ReadOnlyTrait(const ir::Operation *op) explicit ReadOnlyTrait(ir::Operation *op)
: ir::OpTraitBase<ReadOnlyTrait>(op) {} : ir::OpTraitBase<ReadOnlyTrait>(op) {}
}; };
...@@ -34,14 +35,14 @@ class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> { ...@@ -34,14 +35,14 @@ class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> { class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
public: public:
struct Concept { struct Concept {
explicit Concept(void (*infer_shape)(const ir::Operation *)) explicit Concept(void (*infer_shape)(ir::Operation *))
: infer_shape_(infer_shape) {} : infer_shape_(infer_shape) {}
void (*infer_shape_)(const ir::Operation *); void (*infer_shape_)(ir::Operation *);
}; };
template <class ConcreteOp> template <class ConcreteOp>
struct Model : public Concept { struct Model : public Concept {
static void InferShape(const ir::Operation *op) { static void InferShape(ir::Operation *op) {
ConcreteOp concret_op = ConcreteOp(op); ConcreteOp concret_op = ConcreteOp(op);
if (concret_op == nullptr) throw("concret_op is nullptr"); if (concret_op == nullptr) throw("concret_op is nullptr");
concret_op.InferShape(); concret_op.InferShape();
...@@ -53,7 +54,7 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> { ...@@ -53,7 +54,7 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
} }
}; };
InferShapeInterface(const ir::Operation *op, Concept *impl) InferShapeInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {} : ir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {}
void InferShape() { impl_->infer_shape_(operation()); } void InferShape() { impl_->infer_shape_(operation()); }
...@@ -62,6 +63,18 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> { ...@@ -62,6 +63,18 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
Concept *impl_; Concept *impl_;
}; };
ir::AttributeMap CreateAttributeMap(std::vector<std::string> attribute_names,
std::vector<std::string> attributes) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::AttributeMap attr_map;
for (size_t i = 0; i < attribute_names.size(); i++) {
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attributes[i]);
attr_map.insert(
std::pair<std::string, ir::Attribute>(attribute_names[i], attr_value));
}
return attr_map;
}
// Define op1. // Define op1.
class Operation1 : public ir::Op<Operation1> { class Operation1 : public ir::Op<Operation1> {
public: public:
...@@ -81,6 +94,22 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -81,6 +94,22 @@ class Operation1 : public ir::Op<Operation1> {
throw("Type of attribute: parameter_name is not right."); throw("Type of attribute: parameter_name is not right.");
} }
} }
static void build(const ir::Builder &builder,
ir::OperationArgument &argument) { // NOLINT
std::vector<ir::OpResult> inputs = {};
std::vector<ir::Type> output_types = {
ir::Float32Type::get(builder.context())};
std::unordered_map<std::string, ir::Attribute> attributes =
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"});
argument.addOperands<std::vector<ir::OpResult>::iterator>(inputs.begin(),
inputs.end());
argument.addTypes<std::vector<ir::Type>::iterator>(output_types.begin(),
output_types.end());
argument.addAttributes<
std::unordered_map<std::string, ir::Attribute>::iterator>(
attributes.begin(), attributes.end());
}
}; };
const char *Operation1::attributes_name[attributes_num] = {"op1_attr1", const char *Operation1::attributes_name[attributes_num] = {"op1_attr1",
"op1_attr2"}; "op1_attr2"};
...@@ -105,9 +134,7 @@ class Operation2 ...@@ -105,9 +134,7 @@ class Operation2
throw("Type of attribute: parameter_name is not right."); throw("Type of attribute: parameter_name is not right.");
} }
} }
static void InferShape() { static void InferShape() { VLOG(0) << "This is op2's InferShape interface."; }
std::cout << "This is op2's InferShape interface." << std::endl;
}
}; };
const char *Operation2::attributes_name[attributes_num] = {"op2_attr1", const char *Operation2::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"}; "op2_attr2"};
...@@ -125,23 +152,11 @@ class TestDialect : public ir::Dialect { ...@@ -125,23 +152,11 @@ class TestDialect : public ir::Dialect {
void initialize() { RegisterOps<Operation1, Operation2>(); } void initialize() { RegisterOps<Operation1, Operation2>(); }
}; };
ir::AttributeMap CreateAttributeMap(std::vector<std::string> attribute_names,
std::vector<std::string> attributes) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::AttributeMap attr_map;
for (size_t i = 0; i < attribute_names.size(); i++) {
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attributes[i]);
attr_map.insert(
std::pair<std::string, ir::Attribute>(attribute_names[i], attr_value));
}
return attr_map;
}
TEST(op_test, op_test) { TEST(op_test, op_test) {
// (1) Register Dialect, Operation1, Operation2 into IrContext. // (1) Register Dialect, Operation1, Operation2 into IrContext.
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>(); ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
std::cout << test_dialect << std::endl; EXPECT_EQ(test_dialect != nullptr, true);
// (2) Get registered operations. // (2) Get registered operations.
std::string op1_name = Operation1::name(); std::string op1_name = Operation1::name();
...@@ -158,18 +173,18 @@ TEST(op_test, op_test) { ...@@ -158,18 +173,18 @@ TEST(op_test, op_test) {
// (3) Test uses for op. // (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {}; std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op = ir::Operation *op2 =
ir::Operation::create(op_inputs, ir::Operation::create(op_inputs,
op_output_types, op_output_types,
CreateAttributeMap({"op2_attr1", "op2_attr2"}, CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"}), {"op2_attr1", "op2_attr2"}),
op2_info); op2_info);
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>(); ReadOnlyTrait trait = op2->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op); EXPECT_EQ(trait.operation(), op2);
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>(); InferShapeInterface interface = op2->dyn_cast<InferShapeInterface>();
interface.InferShape(); interface.InferShape();
Operation2 Op2 = op->dyn_cast<Operation2>(); Operation2 Op2 = op2->dyn_cast<Operation2>();
EXPECT_EQ(Op2.operation(), op); EXPECT_EQ(Op2.operation(), op2);
op->destroy(); op2->destroy();
} }
...@@ -57,7 +57,7 @@ TEST(program_test, program) { ...@@ -57,7 +57,7 @@ TEST(program_test, program) {
// (2) Create an empty program object // (2) Create an empty program object
ir::Program program; ir::Program program;
// ir::Program *program = new ir::Program(); // ir::Program *program = new ir::Program();
EXPECT_EQ(program.ops().size() == 0, true); EXPECT_EQ(program.block()->size() == 0, true);
// (3) Create a float32 DenseTensor Parameter and save into Program // (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx); ir::Type fp32_dtype = ir::Float32Type::get(ctx);
...@@ -207,8 +207,7 @@ TEST(program_test, program) { ...@@ -207,8 +207,7 @@ TEST(program_test, program) {
program.SetParameter("c", std::move(parameter_c)); program.SetParameter("c", std::move(parameter_c));
// (8) Traverse Program // (8) Traverse Program
std::list<ir::Operation *> ops = program.ops(); EXPECT_EQ(program.block()->size() == 4, true);
EXPECT_EQ(ops.size() == 4, true);
EXPECT_EQ(program.parameters_num() == 3, true); EXPECT_EQ(program.parameters_num() == 3, true);
} }
...@@ -220,7 +219,7 @@ TEST(program_test, slice_combine_test) { ...@@ -220,7 +219,7 @@ TEST(program_test, slice_combine_test) {
// (2) Create an empty program object // (2) Create an empty program object
ir::Program program; ir::Program program;
// ir::Program *program = new ir::Program(); // ir::Program *program = new ir::Program();
EXPECT_EQ(program.ops().size() == 0, true); EXPECT_EQ(program.block()->size() == 0, true);
// (3) Create a float32 DenseTensor Parameter and save into Program // (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx); ir::Type fp32_dtype = ir::Float32Type::get(ctx);
...@@ -267,6 +266,5 @@ TEST(program_test, slice_combine_test) { ...@@ -267,6 +266,5 @@ TEST(program_test, slice_combine_test) {
program.InsertOp(slice_op); program.InsertOp(slice_op);
// (8) Traverse Program // (8) Traverse Program
std::list<ir::Operation *> ops = program.ops(); EXPECT_EQ(program.block()->size() == 4, true);
EXPECT_EQ(ops.size() == 4, true);
} }
...@@ -43,7 +43,7 @@ TEST(value_test, value_test) { ...@@ -43,7 +43,7 @@ TEST(value_test, value_test) {
op1_output_types, op1_output_types,
CreateAttributeMap("op1_name", "op1_attr"), CreateAttributeMap("op1_name", "op1_attr"),
nullptr); nullptr);
std::cout << op1->print() << std::endl; VLOG(0) << op1->print();
// 2. Construct OP2: b = OP2(); // 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {}; std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)}; std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
...@@ -52,7 +52,7 @@ TEST(value_test, value_test) { ...@@ -52,7 +52,7 @@ TEST(value_test, value_test) {
op2_output_types, op2_output_types,
CreateAttributeMap("op2_name", "op2_attr"), CreateAttributeMap("op2_name", "op2_attr"),
nullptr); nullptr);
std::cout << op2->print() << std::endl; VLOG(0) << op2->print() << std::endl;
// 3. Construct OP3: c = OP3(a, b); // 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
op2->GetResultByIndex(0)}; op2->GetResultByIndex(0)};
...@@ -62,7 +62,7 @@ TEST(value_test, value_test) { ...@@ -62,7 +62,7 @@ TEST(value_test, value_test) {
op3_output_types, op3_output_types,
CreateAttributeMap("op3_name", "op3_attr"), CreateAttributeMap("op3_name", "op3_attr"),
nullptr); nullptr);
std::cout << op3->print() << std::endl; VLOG(0) << op3->print() << std::endl;
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c); // 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0), std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
op3->GetResultByIndex(0)}; op3->GetResultByIndex(0)};
...@@ -75,7 +75,7 @@ TEST(value_test, value_test) { ...@@ -75,7 +75,7 @@ TEST(value_test, value_test) {
op4_output_types, op4_output_types,
CreateAttributeMap("op4_name", "op4_attr"), CreateAttributeMap("op4_name", "op4_attr"),
nullptr); nullptr);
std::cout << op4->print() << std::endl; VLOG(0) << op4->print() << std::endl;
// Test 1: // Test 1:
EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1); EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1);
...@@ -103,12 +103,12 @@ TEST(value_test, value_test) { ...@@ -103,12 +103,12 @@ TEST(value_test, value_test) {
EXPECT_EQ(iter.owner(), op3); EXPECT_EQ(iter.owner(), op3);
// destroy // destroy
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op4->destroy(); op4->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op3->destroy(); op3->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op2->destroy(); op2->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl; VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op1->destroy(); op1->destroy();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册