未验证 提交 49bedfd3 编写于 作者: W winter-wang 提交者: GitHub

[IR] refine the program data structure. (#54220)

上级 4bd5b695
......@@ -113,7 +113,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
op_attribute_map,
{src_vec_type[defining_info.idx_in_vector]},
op_info);
program->InsertOp(operation);
program->block()->push_back(operation);
ir::OpResult target_op_result = operation->GetResultByIndex(0);
(*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
return operation;
......@@ -137,7 +137,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
ir::Operation* operation =
ir::Operation::create(src_values, {}, {target_vec_type}, op_info);
program->InsertOp(operation);
program->block()->push_back(operation);
return operation;
}
......@@ -282,7 +282,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation);
program->block()->push_back(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
return operation;
......@@ -300,7 +300,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation);
program->block()->push_back(operation);
RecordOpResultMapping(param_map, op_desc, operation, arg_to_idx);
return operation;
......@@ -316,7 +316,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
auto op_info = LoopkUpOpInfo(ctx, op_desc);
ir::Operation* operation =
ir::Operation::create(op_inputs, {}, op_output_types, op_info);
program->InsertOp(operation);
program->block()->push_back(operation);
return operation;
}
......
......@@ -80,7 +80,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::create(
{}, op_attribute_map, {translated_var_type}, op_info);
program->InsertOp(operation);
program->block()->push_back(operation);
param_map[var->Name()] =
VariableDefiningInfo(operation->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << operation;
......
......@@ -46,6 +46,12 @@ class Block {
iterator insert(const_iterator iterator, Operation *op);
void clear();
Region *GetParentRegion() const { return parent_; }
Operation *GetParentOp() const {
return parent_ ? parent_->GetParentOp() : nullptr;
}
private:
Block(Block &) = delete;
Block &operator=(const Block &) = delete;
......
......@@ -33,4 +33,6 @@ std::vector<Attribute> ArrayAttribute::data() const {
return storage()->GetAsKey();
}
void* PointerAttribute::data() const { return storage()->GetAsKey(); }
} // namespace ir
......@@ -25,7 +25,7 @@ class StrAttribute : public Attribute {
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
bool operator<(const StrAttribute &right) const {
bool operator<(const StrAttribute& right) const {
return storage() < right.storage();
}
......@@ -94,4 +94,13 @@ class ArrayAttribute : public Attribute {
Attribute operator[](size_t index) const { return data()[index]; }
};
class PointerAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage);
void* data() const;
};
} // namespace ir
......@@ -83,6 +83,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32_tAttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64_tAttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
struct ArrayAttributeStorage : public AttributeStorage {
using ParamKey = std::vector<Attribute>;
......
......@@ -40,11 +40,13 @@ void BuiltinDialect::initialize() {
ir::BoolAttribute,
ir::FloatAttribute,
ir::DoubleAttribute,
ir::PointerAttribute,
ir::Int32_tAttribute,
ir::Int64_tAttribute,
ir::ArrayAttribute>();
RegisterOps<ir::GetParameterOp,
RegisterOps<ir::ModuleOp,
ir::GetParameterOp,
ir::SetParameterOp,
ir::CombineOp,
ir::SliceOp>();
......
......@@ -19,6 +19,60 @@
#include "paddle/phi/core/enforce.h"
namespace ir {
const char *ModuleOp::attributes_name[attributes_num] = {"program"};
Program *ModuleOp::program() {
const AttributeMap &attr = operation()->attributes();
auto iter = attr.find("program");
if (iter == attr.end() || !iter->second) return nullptr;
return static_cast<Program *>(
iter->second.dyn_cast<PointerAttribute>().data());
}
Block *ModuleOp::block() {
assert(operation() != nullptr);
assert(operation()->num_regions() == 1);
assert(operation()->GetRegion(0).size() == 1);
return operation()->GetRegion(0).front();
}
ModuleOp ModuleOp::create(IrContext *context, Program *pointer) {
ir::OpInfo info = context->GetRegisteredOpInfo(name());
OperationArgument argument(info);
argument.AddRegion()->emplace_back();
argument.addAttribute("program", PointerAttribute::get(context, pointer));
return ModuleOp(Operation::create(std::move(argument)));
}
void ModuleOp::destroy() {
if (operation()) {
operation()->destroy();
*this = ModuleOp(nullptr);
}
}
void ModuleOp::verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type:
if (inputs.size() != 0) {
throw("The size of inputs must be equal to 0.");
}
// Verify if attributes contain attribute name in attributes_name:
auto iter = attributes.find("program");
if (iter == attributes.end() || !iter->second.isa<PointerAttribute>()) {
throw("Type of attribute: program is not right.");
}
// Verify outputs type:
if (outputs.size() != 0) {
throw("The size of outputs must be equal to 0.");
}
}
const char *GetParameterOp::attributes_name[attributes_num] = {
"parameter_name"};
......
......@@ -18,6 +18,31 @@
namespace ir {
class Program;
class Block;
///
/// \brief ModuleOp
///
class ModuleOp : public ir::Op<ModuleOp> {
public:
using Op::Op;
static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Program *program();
Block *block();
//
// As the top operation, ModuleOp only support create&destroye through
// below interface: "create"&"destroy".
static ModuleOp create(IrContext *context, Program *pointer);
void destroy();
};
///
/// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute,
/// StrAttribute})
......
......@@ -66,7 +66,7 @@ class InterfaceValue {
class OpBase {
public:
explicit OpBase(Operation *operation) : operation_(operation) {}
explicit OpBase(Operation *operation = nullptr) : operation_(operation) {}
Operation *operation() const { return operation_; }
......@@ -76,6 +76,8 @@ class OpBase {
Operation *operator->() const { return operation_; }
IrContext *ir_context() const { return operation_->ir_context(); }
private:
Operation *operation_; // Not owned
};
......@@ -91,7 +93,7 @@ class OpTraitBase : public OpBase {
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
static ConcreteTrait dyn_cast(Operation *op) {
if (op->HasTrait<ConcreteTrait>()) {
if (op && op->HasTrait<ConcreteTrait>()) {
return ConcreteTrait(op);
}
return ConcreteTrait(nullptr);
......@@ -109,7 +111,7 @@ class OpInterfaceBase : public OpBase {
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
static ConcreteInterface dyn_cast(Operation *op) {
if (op->HasInterface<ConcreteInterface>()) {
if (op && op->HasInterface<ConcreteInterface>()) {
return ConcreteInterface(
op, op->op_info().GetInterfaceImpl<ConcreteInterface>());
}
......@@ -182,7 +184,7 @@ class Op : public OpBase {
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
static ConcreteOp dyn_cast(Operation *op) {
if (op->op_info().id() == TypeId::get<ConcreteOp>()) {
if (op && op->op_info().id() == TypeId::get<ConcreteOp>()) {
return ConcreteOp(op);
}
return ConcreteOp(nullptr);
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
......@@ -21,7 +22,7 @@
namespace ir {
Operation *Operation::create(OperationArgument &&argument) {
Operation *op = create(argument.inputs,
argument.attribute,
argument.attributes,
argument.output_types,
argument.info,
argument.regions.size());
......@@ -36,13 +37,13 @@ Operation *Operation::create(OperationArgument &&argument) {
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attribute,
const AttributeMap &attributes,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info,
size_t num_regions) {
// 0. Verify
if (op_info) {
op_info.verify(inputs, output_types, attribute);
op_info.verify(inputs, output_types, attributes);
}
// 1. Calculate the required memory size for OpResults + Operation +
// OpOperands.
......@@ -76,7 +77,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
}
// 3.2. Construct Operation.
Operation *op = new (base_ptr)
Operation(attribute, op_info, num_results, num_operands, num_regions);
Operation(attributes, op_info, num_results, num_operands, num_regions);
base_ptr += sizeof(Operation);
// 3.3. Construct OpOperands.
if ((reinterpret_cast<uintptr_t>(base_ptr) & 0x7) != 0) {
......@@ -160,12 +161,12 @@ void Operation::destroy() {
IrContext *Operation::ir_context() const { return op_info_.ir_context(); }
Operation::Operation(const AttributeMap &attribute,
Operation::Operation(const AttributeMap &attributes,
ir::OpInfo op_info,
uint32_t num_results,
uint32_t num_operands,
uint32_t num_regions)
: attribute_(attribute),
: attributes_(attributes),
op_info_(op_info),
num_results_(num_results),
num_operands_(num_operands),
......@@ -223,6 +224,23 @@ std::string Operation::print() {
std::string Operation::op_name() const { return op_info_.name(); }
Region *Operation::GetParentRegion() const {
return parent_ ? parent_->GetParentRegion() : nullptr;
}
Operation *Operation::GetParentOp() const {
return parent_ ? parent_->GetParentOp() : nullptr;
}
Program *Operation::GetParentProgram() {
Operation *op = this;
while (Operation *parent_op = op->GetParentOp()) {
op = parent_op;
}
ModuleOp module_op = op->dyn_cast<ModuleOp>();
return module_op ? module_op.program() : nullptr;
}
Region &Operation::GetRegion(unsigned index) {
assert(index < num_regions_ && "invalid region index");
return regions_[index];
......
......@@ -34,7 +34,7 @@ class alignas(8) Operation final {
/// used in conjunction.
///
static Operation *create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attribute,
const AttributeMap &attributes,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info,
size_t num_regions = 0);
......@@ -45,8 +45,6 @@ class alignas(8) Operation final {
///
void destroy();
Block *parent() const { return parent_; }
IrContext *ir_context() const;
ir::OpResult GetResultByIndex(uint32_t index) const;
......@@ -55,7 +53,11 @@ class alignas(8) Operation final {
std::string print();
const AttributeMap &attribute() const { return attribute_; }
const AttributeMap &attributes() const { return attributes_; }
void SetAttribute(const std::string &key, Attribute value) {
attributes_[key] = value;
}
ir::OpInfo op_info() const { return op_info_; }
......@@ -82,11 +84,13 @@ class alignas(8) Operation final {
return op_info_.HasInterface<Interface>();
}
Program *parent_program() const { return parent_program_; }
Block *GetParentBlock() const { return parent_; }
void set_parent_program(Program *parent_program) {
parent_program_ = parent_program;
}
Region *GetParentRegion() const;
Operation *GetParentOp() const;
Program *GetParentProgram();
/// Returns the region held by this operation at position 'index'.
Region &GetRegion(unsigned index);
......@@ -115,7 +119,7 @@ class alignas(8) Operation final {
static T call(Operation *op) { return T::dyn_cast(op); }
};
AttributeMap attribute_;
AttributeMap attributes_;
OpInfo op_info_;
......@@ -124,7 +128,6 @@ class alignas(8) Operation final {
const uint32_t num_regions_ = 0;
Region *regions_{nullptr};
Program *parent_program_{nullptr};
Block *parent_{nullptr};
};
......
......@@ -32,7 +32,7 @@ using AttributeMap = std::unordered_map<std::string, Attribute>;
// with the builder APIs.
struct OperationArgument {
std::vector<OpResult> inputs;
AttributeMap attribute;
AttributeMap attributes;
std::vector<Type> output_types;
OpInfo info;
std::vector<std::unique_ptr<Region>> regions;
......@@ -41,12 +41,12 @@ struct OperationArgument {
OperationArgument(IrContext* ir_context, const std::string& name);
explicit OperationArgument(OpInfo info) : info(info) {}
OperationArgument(const std::vector<OpResult>& operands,
const AttributeMap& named_attr,
const AttributeMap& attributes,
const std::vector<Type>& types,
OpInfo info,
std::vector<std::unique_ptr<Region>>&& regions = {})
: inputs(operands),
attribute(named_attr),
attributes(attributes),
output_types(types),
info(info),
regions(std::move(regions)) {}
......@@ -59,13 +59,18 @@ struct OperationArgument {
/// Add an attribute with the specified name.
void addAttribute(const std::string& name, Attribute attr) {
this->attribute[name] = attr;
attributes[name] = attr;
}
/// Add an array of named attributes.
template <class InputIt>
void addAttributes(InputIt first, InputIt last);
/// Get the context held by this operation state.
IrContext* getContext() const { return info.ir_context(); }
Region* AddRegion() {
regions.emplace_back(new Region);
return regions.back().get();
}
};
template <class InputIt>
......@@ -83,7 +88,7 @@ void OperationArgument::addTypes(InputIt first, InputIt last) {
template <class InputIt>
void OperationArgument::addAttributes(InputIt first, InputIt last) {
while (first != last) {
attribute[first->first] = first->second;
attributes[first->first] = first->second;
++first;
}
}
......
......@@ -16,11 +16,17 @@
#include "paddle/ir/core/ir_context.h"
namespace ir {
Program::~Program() = default;
void Program::InsertOp(Operation* op) {
block_.push_back(op);
op->set_parent_program(this);
Program::Program(IrContext* context) {
module_ = ModuleOp::create(context, this);
}
Program::Program() : Program(IrContext::Instance()) {}
Program::~Program() {
if (module_) {
module_.destroy();
}
}
Parameter* Program::GetParameter(std::string name) const {
......
......@@ -19,10 +19,13 @@
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
namespace ir {
class IrContext;
///
/// \brief Program is an abstraction of model structure, divided into
/// computational graphs and weights. At the current stage, a computational
......@@ -33,27 +36,34 @@ namespace ir {
///
class Program {
public:
using ParameterMap =
std::unordered_map<std::string, std::unique_ptr<Parameter>>;
explicit Program(IrContext* context);
Program();
Program(Program&&) = delete;
Program(const Program& program) = delete;
Program& operator=(const Program&) = delete;
Program& operator=(Program&&);
~Program();
Block* block() { return &block_; }
size_t parameters_num() const { return parameters_.size(); }
///
/// \brief Insert the Operation* constructed by Operation::create(...) into
/// this Program. NOTE: At this time, the memory management permission of
/// Operation* will be owned by this Program. The user does not need to call
/// Operation::destroy() manually
///
void InsertOp(Operation* op);
ModuleOp module_op() { return module_; }
Parameter* GetParameter(std::string name) const;
Block* block() { return module_.block(); }
Parameter* GetParameter(std::string name) const;
void SetParameter(std::string name, std::unique_ptr<Parameter>&& parameter);
ParameterMap& parameters() { return parameters_; }
void set_parameters(ParameterMap&& parameters) {
parameters_ = std::move(parameters);
}
private:
Block block_;
std::unordered_map<std::string, std::unique_ptr<Parameter>> parameters_;
// computation graph
ModuleOp module_;
// weight
ParameterMap parameters_;
};
std::ostream& operator<<(std::ostream& os, Program& program);
......
......@@ -22,6 +22,9 @@ void Region::push_back(Block *block) {
block->set_parent(this);
blocks_.push_back(block);
}
void Region::emplace_back() { push_back(new Block); }
void Region::push_front(Block *block) {
block->set_parent(this);
blocks_.push_front(block);
......
......@@ -41,12 +41,15 @@ class Region {
Block *back() const { return blocks_.back(); }
Block *front() const { return blocks_.front(); }
void push_back(Block *block);
void emplace_back();
void push_front(Block *block);
iterator insert(const_iterator position, Block *block);
void clear();
void TakeBody(Region &&other);
Operation *GetParentOp() const { return parent_; }
private:
Region(Region &) = delete;
Region &operator=(const Region &) = delete;
......
......@@ -17,10 +17,12 @@
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
/// \brief Define built-in Trait, derived from OpTraitBase.
......@@ -133,7 +135,7 @@ class Operation2
throw("Type of attribute: parameter_name is not right.");
}
}
static void InferShape() { VLOG(0) << "This is op2's InferShape interface."; }
static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
};
const char *Operation2::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"};
......@@ -212,8 +214,8 @@ TEST(op_test, region_test) {
op1_info);
ir::OperationArgument argument(op2_info);
argument.attribute = CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"});
argument.attributes = CreateAttributeMap({"op2_attr1", "op2_attr2"},
{"op2_attr1", "op2_attr2"});
argument.output_types = {ir::Float32Type::get(ctx)};
argument.regions.emplace_back(std::make_unique<ir::Region>());
ir::Region *region = argument.regions.back().get();
......@@ -228,3 +230,26 @@ TEST(op_test, region_test) {
ir::Operation *op2 = ir::Operation::create(std::move(argument));
op2->destroy();
}
TEST(op_test, module_op_death) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(ir::ModuleOp::name());
// (3) Test uses for op.
std::vector<ir::OpResult> inputs{ir::OpResult()};
ir::AttributeMap attrs{{"program", ir::Int32_tAttribute::get(ctx, 1)}};
std::vector<ir::Type> output_types = {ir::Float32Type::get(ctx)};
EXPECT_THROW(ir::Operation::create(inputs, {}, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::create({}, attrs, {}, op_info), const char *);
EXPECT_THROW(ir::Operation::create({}, {}, output_types, op_info),
const char *);
ir::Program program(ctx);
EXPECT_EQ(program.module_op().program(), &program);
EXPECT_EQ(program.module_op().ir_context(), ctx);
program.module_op()->SetAttribute("program",
ir::PointerAttribute::get(ctx, &program));
}
......@@ -18,6 +18,7 @@
#include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h"
......@@ -56,9 +57,7 @@ TEST(program_test, program) {
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
// (2) Create an empty program object
ir::Program program;
// ir::Program *program = new ir::Program();
EXPECT_EQ(program.block()->size() == 0, true);
ir::Program program(ctx);
// (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
......@@ -94,7 +93,14 @@ TEST(program_test, program) {
ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {dense_tensor_dtype}, op1_info);
program.InsertOp(op1);
ir::Block *block = program.block();
block->push_back(op1);
EXPECT_EQ(&program.module_op()->GetRegion(0), block->GetParentRegion());
EXPECT_EQ(program.module_op(), block->GetParentOp());
EXPECT_EQ(&program, op1->GetParentProgram());
EXPECT_EQ(op1->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id());
......@@ -124,7 +130,7 @@ TEST(program_test, program) {
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {dense_tensor_dtype}, op2_info);
program.InsertOp(op2);
block->push_back(op2);
EXPECT_EQ(op2->GetResultByIndex(0).type().dialect().id(),
paddle_dialect->id());
......@@ -155,7 +161,7 @@ TEST(program_test, program) {
op3_attribute,
{dense_tensor_dtype},
op3_info);
program.InsertOp(op3);
block->push_back(op3);
phi::CPUContext *dev_ctx = static_cast<phi::CPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(
......@@ -196,9 +202,12 @@ TEST(program_test, program) {
ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
std::unordered_map<std::string, ir::Attribute> op4_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "c")}};
ir::Operation *op4 = ir::Operation::create(
{op3->GetResultByIndex(0)}, op4_attribute, {}, op4_info);
program.InsertOp(op4);
ir::OperationArgument op4_argument(
{op3->GetResultByIndex(0)}, {}, {}, op4_info);
op4_argument.addAttributes(op4_attribute.begin(), op4_attribute.end());
ir::Operation *op4 = ir::Operation::create(std::move(op4_argument));
block->push_back(op4);
EXPECT_EQ(op4->GetOperandByIndex(0).impl()->source().type().dialect().id(),
paddle_dialect->id());
......@@ -244,7 +253,7 @@ TEST(program_test, slice_combine_test) {
{"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 =
ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info);
program.InsertOp(op1);
program.block()->push_back(op1);
// (5) Def b = GetParameterOp("b")
std::string op2_name = std::string(ir::GetParameterOp::name());
......@@ -253,7 +262,7 @@ TEST(program_test, slice_combine_test) {
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info);
program.InsertOp(op2);
program.block()->push_back(op2);
// (6) Def combine_op = CombineOp("a", "b")
std::string combine_op_name = std::string(ir::CombineOp::name());
......@@ -265,7 +274,7 @@ TEST(program_test, slice_combine_test) {
{},
{output_type},
combine_op_info);
program.InsertOp(combine_op);
program.block()->push_back(combine_op);
// (7) Def slice_op = SliceOp(combine_op, 0)
std::string slice_op_name = std::string(ir::SliceOp::name());
......@@ -276,7 +285,7 @@ TEST(program_test, slice_combine_test) {
{{"index", index_attr}},
{fp32_dtype},
slice_op_info);
program.InsertOp(slice_op);
program.block()->push_back(slice_op);
// (8) Traverse Program
EXPECT_EQ(program.block()->size() == 4, true);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册