Unverified commit 3143d8bf, authored by zhangbo9674, committed by GitHub

[IR] Refine Builder (#54052)

* refine code

* delete some unused code

* refine code of build

* refine code of build

* add block

* refine builder

* refine code

* refine code by comment

* fix compiler bug
Parent 1549fbd3
......@@ -25,40 +25,60 @@ namespace dialect {
class className : public ir::Op<className> { \
public: \
static const char *name() { return OPNAME(op_name); } \
static const char **attributes_name; \
static constexpr const char **attributes_name = nullptr; \
static constexpr uint32_t attributes_num = 0; \
static void verify(const std::vector<ir::OpResult> &inputs, \
const std::vector<ir::Type> &outputs, \
const ir::AttributeMap &attributes) { \
LOG(WARNING) << "This is a fake verify"; \
} \
}; \
const char **className::attributes_name = nullptr;
};
REIGSTER_EMPTY_OP(conv2d, Conv2DOp);
REIGSTER_EMPTY_OP(feed, FeedOp);
REIGSTER_EMPTY_OP(batch_norm, BatchNormOp);
REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_);
REIGSTER_EMPTY_OP(elementwise_add, ElementwiseAddOp);
REIGSTER_EMPTY_OP(pool2d, Pool2DOp);
REIGSTER_EMPTY_OP(flatten_contiguous_range, FlattenContiguousRangeOp);
REIGSTER_EMPTY_OP(matmul_v2, MatmulV2Op);
REIGSTER_EMPTY_OP(reshape2, Reshape2Op);
REIGSTER_EMPTY_OP(softmax_with_cross_entropy, SoftmaxWithCrossEntropyOp);
REIGSTER_EMPTY_OP(reduce_mean, ReduceMeanOp);
REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op);
REIGSTER_EMPTY_OP(fill_constant, FillConstantOp);
REIGSTER_EMPTY_OP(reduce_mean_grad, ReduceMeanGradOp);
REIGSTER_EMPTY_OP(softmax_with_cross_entropy_grad,
SoftmaxWithCrossEntropyGradOp);
REIGSTER_EMPTY_OP(elementwise_add_grad, ElementwiseAddGradOp);
REIGSTER_EMPTY_OP(matmul_v2_grad, MatmulV2GradOp);
REIGSTER_EMPTY_OP(flatten_contiguous_range_grad, FlattenContiguousRangeGradOp);
REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp);
REIGSTER_EMPTY_OP(batch_norm_grad, BatchNormGradOp);
REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp);
REIGSTER_EMPTY_OP(sum, SumOp);
REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op);
// TODO(zhangbo): These placeholder registrations will be removed gradually
// as the corresponding operators are supplemented and properly defined.
REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d
REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed
REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: batch_norm
REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_
REIGSTER_EMPTY_OP(elementwise_add,
ElementwiseAddOp); // To be customized: add (elementwise_add)
REIGSTER_EMPTY_OP(pool2d, Pool2DOp); // To be customized: pool2d
REIGSTER_EMPTY_OP(
flatten_contiguous_range,
FlattenContiguousRangeOp); // flatten (flatten_contiguous_range)
REIGSTER_EMPTY_OP(matmul_v2,
MatmulV2Op); // To be customized: matmul (matmul_v2)
REIGSTER_EMPTY_OP(reshape2, Reshape2Op); // To be customized: reshape
REIGSTER_EMPTY_OP(softmax_with_cross_entropy,
SoftmaxWithCrossEntropyOp); // cross_entropy_with_softmax
// (softmax_with_cross_entropy)
REIGSTER_EMPTY_OP(reduce_mean,
ReduceMeanOp); // To be customized: mean (reduce_mean)
REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); // topk (top_k_v2)
REIGSTER_EMPTY_OP(fill_constant,
FillConstantOp); // To be customized: full (fill_constant)
REIGSTER_EMPTY_OP(reduce_mean_grad,
ReduceMeanGradOp); // To be customized: reduce_mean_grad
REIGSTER_EMPTY_OP(
softmax_with_cross_entropy_grad,
SoftmaxWithCrossEntropyGradOp); // cross_entropy_with_softmax_grad
// (softmax_with_cross_entropy_grad)
REIGSTER_EMPTY_OP(
elementwise_add_grad,
ElementwiseAddGradOp); // To be customized: add_grad (elementwise_add_grad)
REIGSTER_EMPTY_OP(
matmul_v2_grad,
MatmulV2GradOp); // To be customized: matmul_grad (matmul_v2_grad)
REIGSTER_EMPTY_OP(
flatten_contiguous_range_grad,
FlattenContiguousRangeGradOp); // flatten_grad
// (flatten_contiguous_range_grad)
REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad
REIGSTER_EMPTY_OP(batch_norm_grad,
BatchNormGradOp); // To be customized: batch_norm_grad
REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad
REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum)
REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2
} // namespace dialect
} // namespace paddle
......@@ -19,10 +19,6 @@
namespace paddle {
namespace dialect {
#define GET_PD_DIALECT_ATTRIBUTE_LIST \
IntArrayAttribute, ScalarAttribute, DataTypeAttribute, PlaceAttribute, \
DataLayoutAttribute
class IntArrayAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
......
......@@ -29,12 +29,12 @@
namespace paddle {
namespace dialect {
std::shared_ptr<paddle::framework::Variable>
ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) {
if (parameter->type().isa<DenseTensorType>()) {
VLOG(4) << "Convert a DenseTensor Parameter to a variable.";
std::shared_ptr<paddle::framework::Variable> var =
std::make_shared<paddle::framework::Variable>();
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
// Init DenseTensor
auto dim = parameter->type().dyn_cast<DenseTensorType>().dim();
phi::DenseTensorMeta meta(
......@@ -46,7 +46,7 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
parameter->type().dyn_cast<DenseTensorType>().lod(),
parameter->type().dyn_cast<DenseTensorType>().offset());
tensor->set_meta(meta);
paddle::platform::DeviceContext* dev_ctx =
paddle::platform::DeviceContext *dev_ctx =
paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace());
dev_ctx->Alloc(tensor,
......@@ -62,11 +62,11 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter* parameter) {
}
std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
paddle::framework::Variable* var) {
paddle::framework::Variable *var) {
if (var->IsType<phi::DenseTensor>()) {
phi::DenseTensor* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
// Get Meta
ir::IrContext* ctx = ir::IrContext::Instance();
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx);
DenseTensorTypeStorage::Dim dims(tensor->dims().size());
std::copy(tensor->dims().Get(),
......@@ -76,7 +76,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
TransToIrDataLayout(tensor->layout());
DenseTensorTypeStorage::LoD lod = tensor->lod();
size_t offset = tensor->meta().offset;
void* data = tensor->data();
void *data = tensor->data();
ir::Type dense_tensor_type =
DenseTensorType::get(ctx, data_type, dims, data_layout, lod, offset);
return std::make_unique<ir::Parameter>(
......@@ -88,7 +88,7 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
}
}
PaddleDialect::PaddleDialect(ir::IrContext* context)
PaddleDialect::PaddleDialect(ir::IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<PaddleDialect>()) {
initialize();
}
......@@ -136,11 +136,11 @@ void PaddleDialect::initialize() {
FetchV2Op>();
}
void PaddleDialect::PrintType(ir::Type type, std::ostream& os) {
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
os << "tensor<";
auto& dims = tensor_type.dim();
auto &dims = tensor_type.dim();
for (auto d : dims) {
os << d;
os << "x";
......
......@@ -19,8 +19,6 @@
namespace paddle {
namespace dialect {
#define GET_PD_DIALECT_TYPE_LIST paddle::dialect::DenseTensorType
///
/// \brief Define built-in parametric types.
///
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/block.h"
namespace ir {
Block::~Block() { clear(); }
void Block::clear() {
while (!empty()) {
ops_.back()->destroy();
ops_.pop_back();
}
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include "paddle/ir/operation.h"
namespace ir {
class Block {
public:
using iterator = std::list<Operation *>::iterator;
using reverse_iterator = std::list<Operation *>::reverse_iterator;
Block() = default;
~Block();
bool empty() const { return ops_.empty(); }
size_t size() const { return ops_.size(); }
iterator begin() { return ops_.begin(); }
iterator end() { return ops_.end(); }
reverse_iterator rbegin() { return ops_.rbegin(); }
reverse_iterator rend() { return ops_.rend(); }
Operation *back() { return ops_.back(); }
Operation *front() { return ops_.front(); }
void push_back(Operation *op) { ops_.push_back(op); }
void push_front(Operation *op) { ops_.push_front(op); }
std::list<Operation *>::iterator insert(
std::list<Operation *>::const_iterator iterator, Operation *op) {
return ops_.insert(iterator, op);
}
void clear();
private:
Block(Block &) = delete;
void operator=(Block &) = delete;
private:
std::list<Operation *> ops_; // owned
};
} // namespace ir
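
The new Block is a thin owning wrapper over std::list<Operation *>: push_back, push_front and insert mirror the list API, and clear() (run by the destructor) destroys every remaining operation. A minimal usage sketch, assuming op0, op1 and op_middle are hypothetical operations obtained from ir::Operation::create as elsewhere in this patch:

ir::Block block;
block.push_back(op1);                    // append; the block now owns op1
block.push_front(op0);                   // prepend
ir::Block::iterator it = block.begin();
block.insert(++it, op_middle);           // insert before the second position
assert(block.size() == 3);
// ~Block() calls clear(), which pops and destroy()s each remaining op.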
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/builder.h"
namespace ir {
Operation *Builder::insert(Operation *op) {
if (block_) {
block_->insert(insert_point_, op);
} else {
LOG(WARNING) << "Builder's Block is nullptr, insert failed.";
}
return op;
}
/// Creates an operation given the fields represented as an OperationArgument.
Operation *Builder::create(const OperationArgument &argument) {
return insert(Operation::create(argument));
}
/// Creates an operation with the given fields.
Operation *Builder::create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
ir::OpInfo op_info) {
OperationArgument argument(op_info, inputs, output_types, attribute);
return create(argument);
}
} // namespace ir
......@@ -16,7 +16,9 @@
#include <list>
#include "paddle/ir/block.h"
#include "paddle/ir/operation.h"
#include "paddle/ir/program.h"
namespace ir {
///
......@@ -25,25 +27,47 @@ namespace ir {
///
class Builder {
public:
explicit Builder(IrContext *context) : context_(context) {}
explicit Builder(Operation *op) : Builder(op->ir_context()) {}
explicit Builder(IrContext *context,
Block *block,
Block::iterator insert_point)
: context_(context), block_(block), insert_point_(insert_point) {}
static Builder AtBlockBegin(IrContext *context, Block *block) {
return Builder(context, block, block->begin());
}
static Builder AtBlockEnd(IrContext *context, Block *block) {
return Builder(context, block, block->end());
}
IrContext *context() const { return context_; }
Block *block() const { return block_; }
Operation *insert(Operation *op);
/// Creates an operation given the fields represented as an OperationArgument.
Operation *create(const OperationArgument &argument);
/// Creates an operation with the given fields.
Operation *create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
const AttributeMap &attribute,
ir::OpInfo op_info);
/// Create an operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpTy create(Args &&...args) {
OperationArgument argument(context_->GetRegisteredOpInfo(OpTy::name()));
OpTy::build(*this, argument, std::forward<Args>(args)...);
Operation *op = Operation::create(argument);
Operation *op = create(argument);
return op->dyn_cast<OpTy>();
}
private:
IrContext *context_;
// The current op list this builder is inserting into.
// After the design of the block data structure is completed,
// this member will be replaced by the block.
std::list<Operation *> *op_list_ = nullptr;
Block *block_ = nullptr;
// The insertion point within the list that this builder is inserting before.
std::list<Operation *>::iterator insertPoint;
Block::iterator insert_point_;
};
} // namespace ir
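
Taken together, the Builder now carries an explicit Block and insertion point, and every create() funnels through insert(). A hedged sketch of the call pattern, where some_op_info stands in for an OpInfo obtained from ctx->GetRegisteredOpInfo(...):

ir::IrContext *ctx = ir::IrContext::Instance();
ir::Block block;
// Position the builder at the end of the block; AtBlockBegin is analogous.
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, &block);
// The non-template create() packs the fields into an OperationArgument,
// creates the Operation, and insert() places it before the insertion point.
ir::Operation *op = builder.create({},                           // inputs
                                   {ir::Float32Type::get(ctx)},  // output types
                                   {},                           // attributes
                                   some_op_info);                // hypothetical OpInfo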
......@@ -19,13 +19,6 @@
#include "paddle/ir/utils.h"
namespace ir {
///
/// \brief All built-in attributes.
///
#define GET_BUILT_IN_ATTRIBUTE_LIST \
StrAttribute, BoolAttribute, FloatAttribute, DoubleAttribute, \
Int32_tAttribute, Int64_tAttribute, ArrayAttribute
class StrAttribute : public Attribute {
public:
using Attribute::Attribute;
......
......@@ -18,15 +18,6 @@
#include "paddle/ir/type.h"
namespace ir {
///
/// \brief This macro is used to get a list of all built-in types in this file.
/// The built-in Dialect will use this macro to quickly register all built-in
/// types.
///
#define GET_BUILT_IN_TYPE_LIST \
BFloat16Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, \
Int32Type, Int64Type, BoolType, VectorType
///
/// \brief Define built-in parameterless types. Please add the necessary
/// interface functions for built-in types through the macro
......
......@@ -66,18 +66,18 @@ class InterfaceValue {
class OpBase {
public:
explicit OpBase(const Operation *operation) : operation_(operation) {}
explicit OpBase(Operation *operation) : operation_(operation) {}
const Operation *operation() const { return operation_; }
Operation *operation() const { return operation_; }
explicit operator bool() const { return operation() != nullptr; }
operator const Operation *() const { return operation_; }
operator Operation *() const { return operation_; }
const Operation *operator->() const { return operation_; }
Operation *operator->() const { return operation_; }
private:
const Operation *operation_; // Not owned
Operation *operation_; // Not owned
};
///
......@@ -86,11 +86,11 @@ class OpBase {
template <class ConcreteTrait>
class OpTraitBase : public OpBase {
public:
explicit OpTraitBase(const Operation *op) : OpBase(op) {}
explicit OpTraitBase(Operation *op) : OpBase(op) {}
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
static ConcreteTrait dyn_cast(const Operation *op) {
static ConcreteTrait dyn_cast(Operation *op) {
if (op->HasTrait<ConcreteTrait>()) {
return ConcreteTrait(op);
}
......@@ -104,13 +104,11 @@ class OpTraitBase : public OpBase {
template <typename ConcreteInterface>
class OpInterfaceBase : public OpBase {
public:
// explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
static ConcreteInterface dyn_cast(const Operation *op) {
static ConcreteInterface dyn_cast(Operation *op) {
if (op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface(
op, op->op_info().GetInterfaceImpl<ConcreteInterface>());
......@@ -183,7 +181,7 @@ class Op : public OpBase {
using InterfaceList =
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
static ConcreteOp dyn_cast(const Operation *op) {
static ConcreteOp dyn_cast(Operation *op) {
if (op->op_info().id() == TypeId::get<ConcreteOp>()) {
return ConcreteOp(op);
}
......
......@@ -14,6 +14,7 @@
#pragma once
#include <iostream>
#include "paddle/ir/op_info.h"
#include "paddle/ir/operation_utils.h"
#include "paddle/ir/type.h"
......@@ -61,7 +62,7 @@ class alignas(8) Operation final {
std::string op_name() const;
template <typename T>
T dyn_cast() const {
T dyn_cast() {
return CastUtil<T>::call(this);
}
......@@ -89,7 +90,7 @@ class alignas(8) Operation final {
template <typename T, typename Enabler = void>
struct CastUtil {
static T call(const Operation *op) {
static T call(Operation *op) {
throw("Can't dyn_cast to T, T should be a Op or Trait or Interface");
}
};
......@@ -98,7 +99,7 @@ class alignas(8) Operation final {
struct CastUtil<
T,
typename std::enable_if<std::is_base_of<OpBase, T>::value>::type> {
static T call(const Operation *op) { return T::dyn_cast(op); }
static T call(Operation *op) { return T::dyn_cast(op); }
};
AttributeMap attribute_;
......
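
The const-removal above changes Operation::dyn_cast and CastUtil in lockstep: both now take a mutable Operation *, so traits and interfaces obtained through dyn_cast can act on the operation. A sketch using ReadOnlyTrait and InferShapeInterface, the sample trait and interface defined in the op test further below:

// Given some ir::Operation *op created elsewhere:
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();  // trait cast
InferShapeInterface iface = op->dyn_cast<InferShapeInterface>();
if (iface) {           // OpBase's explicit operator bool
  iface.InferShape();  // dispatches to the registered infer_shape_
}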
......@@ -95,9 +95,11 @@ class ProgramPrinter : public Printer {
explicit ProgramPrinter(std::ostream& os) : Printer(os), cur_var_number(0) {}
void Print(ir::Program& program) {
for (auto* op : program.ops()) {
PrintOperation(op);
auto iterator = program.block()->begin();
while (iterator != program.block()->end()) {
PrintOperation(*iterator);
os << newline;
iterator++;
}
}
......
......@@ -16,14 +16,10 @@
#include "paddle/ir/ir_context.h"
namespace ir {
Program::~Program() {
for (auto op : ops_) {
op->destroy();
}
}
Program::~Program() = default;
void Program::InsertOp(Operation* op) {
ops_.push_back(op);
block_.push_back(op);
op->set_parent_program(this);
}
......
......@@ -17,6 +17,7 @@
#include <list>
#include <unordered_map>
#include "paddle/ir/block.h"
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/operation.h"
#include "paddle/ir/parameter.h"
......@@ -34,7 +35,7 @@ class Program {
public:
~Program();
std::list<Operation*> ops() const { return ops_; }
Block* block() { return &block_; }
size_t parameters_num() const { return parameters_.size(); }
......@@ -51,8 +52,7 @@ class Program {
void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter);
private:
std::list<Operation*> ops_; // owned
Block block_;
std::unordered_map<std::string, std::unique_ptr<Parameter>> parameters_;
};
......
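
Since Program now stores its operations in an owned Block, the ops() accessor disappears and traversal goes through block(), exactly as in the printer change above. A short sketch (op is a hypothetical operation):

ir::Program program;
program.InsertOp(op);  // forwards to block_.push_back(op)
for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
  ir::Operation *cur = *it;
  // ... inspect or print cur ...
}
// ~Program() destroys block_, whose clear() destroy()s every op.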
......@@ -14,6 +14,7 @@
#include <gtest/gtest.h>
#include "paddle/ir/builder.h"
#include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/dialect.h"
......@@ -23,7 +24,7 @@
/// \brief Define built-in Trait, derived from OpTraitBase.
class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
public:
explicit ReadOnlyTrait(const ir::Operation *op)
explicit ReadOnlyTrait(ir::Operation *op)
: ir::OpTraitBase<ReadOnlyTrait>(op) {}
};
......@@ -34,14 +35,14 @@ class ReadOnlyTrait : public ir::OpTraitBase<ReadOnlyTrait> {
class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
public:
struct Concept {
explicit Concept(void (*infer_shape)(const ir::Operation *))
explicit Concept(void (*infer_shape)(ir::Operation *))
: infer_shape_(infer_shape) {}
void (*infer_shape_)(const ir::Operation *);
void (*infer_shape_)(ir::Operation *);
};
template <class ConcreteOp>
struct Model : public Concept {
static void InferShape(const ir::Operation *op) {
static void InferShape(ir::Operation *op) {
ConcreteOp concrete_op = ConcreteOp(op);
if (concrete_op == nullptr) throw("concrete_op is nullptr");
concrete_op.InferShape();
......@@ -53,7 +54,7 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
}
};
InferShapeInterface(const ir::Operation *op, Concept *impl)
InferShapeInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<InferShapeInterface>(op), impl_(impl) {}
void InferShape() { impl_->infer_shape_(operation()); }
......@@ -62,6 +63,18 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
Concept *impl_;
};
ir::AttributeMap CreateAttributeMap(std::vector<std::string> attribute_names,
std::vector<std::string> attributes) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::AttributeMap attr_map;
for (size_t i = 0; i < attribute_names.size(); i++) {
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attributes[i]);
attr_map.insert(
std::pair<std::string, ir::Attribute>(attribute_names[i], attr_value));
}
return attr_map;
}
// Define op1.
class Operation1 : public ir::Op<Operation1> {
public:
......@@ -81,6 +94,22 @@ class Operation1 : public ir::Op<Operation1> {
throw("Type of attribute: parameter_name is not right.");
}
}
static void build(const ir::Builder &builder,
ir::OperationArgument &argument) { // NOLINT
std::vector<ir::OpResult> inputs = {};
std::vector<ir::Type> output_types = {
ir::Float32Type::get(builder.context())};
std::unordered_map<std::string, ir::Attribute> attributes =
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"});
argument.addOperands<std::vector<ir::OpResult>::iterator>(inputs.begin(),
inputs.end());
argument.addTypes<std::vector<ir::Type>::iterator>(output_types.begin(),
output_types.end());
argument.addAttributes<
std::unordered_map<std::string, ir::Attribute>::iterator>(
attributes.begin(), attributes.end());
}
};
const char *Operation1::attributes_name[attributes_num] = {"op1_attr1",
"op1_attr2"};
......@@ -105,9 +134,7 @@ class Operation2
throw("Type of attribute: parameter_name is not right.");
}
}
static void InferShape() {
std::cout << "This is op2's InferShape interface." << std::endl;
}
static void InferShape() { VLOG(0) << "This is op2's InferShape interface."; }
};
const char *Operation2::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"};
......@@ -125,23 +152,11 @@ class TestDialect : public ir::Dialect {
void initialize() { RegisterOps<Operation1, Operation2>(); }
};
ir::AttributeMap CreateAttributeMap(std::vector<std::string> attribute_names,
std::vector<std::string> attributes) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::AttributeMap attr_map;
for (size_t i = 0; i < attribute_names.size(); i++) {
ir::Attribute attr_value = ir::StrAttribute::get(ctx, attributes[i]);
attr_map.insert(
std::pair<std::string, ir::Attribute>(attribute_names[i], attr_value));
}
return attr_map;
}
TEST(op_test, op_test) {
// (1) Register Dialect, Operation1, Operation2 into IrContext.
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Dialect *test_dialect = ctx->GetOrRegisterDialect<TestDialect>();
std::cout << test_dialect << std::endl;
EXPECT_EQ(test_dialect != nullptr, true);
// (2) Get registered operations.
std::string op1_name = Operation1::name();
......@@ -158,18 +173,18 @@ TEST(op_test, op_test) {
// (3) Test uses for op.
std::vector<ir::OpResult> op_inputs = {};
std::vector<ir::Type> op_output_types = {ir::Float32Type::get(ctx)};
ir::Operation *op =
ir::Operation *op2 =
ir::Operation::create(op_inputs,
op_output_types,
CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"}),
op2_info);
ReadOnlyTrait trait = op->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op);
InferShapeInterface interface = op->dyn_cast<InferShapeInterface>();
ReadOnlyTrait trait = op2->dyn_cast<ReadOnlyTrait>();
EXPECT_EQ(trait.operation(), op2);
InferShapeInterface interface = op2->dyn_cast<InferShapeInterface>();
interface.InferShape();
Operation2 Op2 = op->dyn_cast<Operation2>();
EXPECT_EQ(Op2.operation(), op);
op->destroy();
Operation2 Op2 = op2->dyn_cast<Operation2>();
EXPECT_EQ(Op2.operation(), op2);
op2->destroy();
}
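
The test above still drives Operation::create directly; with the build() hook added to Operation1, the same op can be constructed through the new templated Builder::create. A hedged sketch, assuming TestDialect has been registered as in the test:

ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<TestDialect>();
ir::Program program;
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program.block());
// create<Operation1>() looks up the registered OpInfo, calls
// Operation1::build to fill an OperationArgument, then creates and
// inserts the operation at the end of the program's block.
Operation1 op1 = builder.create<Operation1>();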
......@@ -57,7 +57,7 @@ TEST(program_test, program) {
// (2) Create an empty program object
ir::Program program;
// ir::Program *program = new ir::Program();
EXPECT_EQ(program.ops().size() == 0, true);
EXPECT_EQ(program.block()->size() == 0, true);
// (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
......@@ -207,8 +207,7 @@ TEST(program_test, program) {
program.SetParameter("c", std::move(parameter_c));
// (8) Traverse Program
std::list<ir::Operation *> ops = program.ops();
EXPECT_EQ(ops.size() == 4, true);
EXPECT_EQ(program.block()->size() == 4, true);
EXPECT_EQ(program.parameters_num() == 3, true);
}
......@@ -220,7 +219,7 @@ TEST(program_test, slice_combine_test) {
// (2) Create an empty program object
ir::Program program;
// ir::Program *program = new ir::Program();
EXPECT_EQ(program.ops().size() == 0, true);
EXPECT_EQ(program.block()->size() == 0, true);
// (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
......@@ -267,6 +266,5 @@ TEST(program_test, slice_combine_test) {
program.InsertOp(slice_op);
// (8) Traverse Program
std::list<ir::Operation *> ops = program.ops();
EXPECT_EQ(ops.size() == 4, true);
EXPECT_EQ(program.block()->size() == 4, true);
}
......@@ -43,7 +43,7 @@ TEST(value_test, value_test) {
op1_output_types,
CreateAttributeMap("op1_name", "op1_attr"),
nullptr);
std::cout << op1->print() << std::endl;
VLOG(0) << op1->print();
// 2. Construct OP2: b = OP2();
std::vector<ir::OpResult> op2_inputs = {};
std::vector<ir::Type> op2_output_types = {ir::Float32Type::get(ctx)};
......@@ -52,7 +52,7 @@ TEST(value_test, value_test) {
op2_output_types,
CreateAttributeMap("op2_name", "op2_attr"),
nullptr);
std::cout << op2->print() << std::endl;
VLOG(0) << op2->print() << std::endl;
// 3. Construct OP3: c = OP3(a, b);
std::vector<ir::OpResult> op3_inputs = {op1->GetResultByIndex(0),
op2->GetResultByIndex(0)};
......@@ -62,7 +62,7 @@ TEST(value_test, value_test) {
op3_output_types,
CreateAttributeMap("op3_name", "op3_attr"),
nullptr);
std::cout << op3->print() << std::endl;
VLOG(0) << op3->print() << std::endl;
// 4. Construct OP4: d, e, f, g, h, i, j = OP4(a, c);
std::vector<ir::OpResult> op4_inputs = {op1->GetResultByIndex(0),
op3->GetResultByIndex(0)};
......@@ -75,7 +75,7 @@ TEST(value_test, value_test) {
op4_output_types,
CreateAttributeMap("op4_name", "op4_attr"),
nullptr);
std::cout << op4->print() << std::endl;
VLOG(0) << op4->print() << std::endl;
// Test 1:
EXPECT_EQ(op1->GetResultByIndex(0).GetDefiningOp(), op1);
......@@ -103,12 +103,12 @@ TEST(value_test, value_test) {
EXPECT_EQ(iter.owner(), op3);
// destroy
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op4->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op3->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op2->destroy();
std::cout << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
VLOG(0) << op1->GetResultByIndex(0).print_ud_chain() << std::endl;
op1->destroy();
}