Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
49bedfd3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
49bedfd3
编写于
6月 02, 2023
作者:
W
winter-wang
提交者:
GitHub
6月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[IR] refine the program data structure. (#54220)
上级
4bd5b695
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
249 addition
and
66 deletion
+249
-66
paddle/fluid/translator/op_translator.cc
paddle/fluid/translator/op_translator.cc
+5
-5
paddle/fluid/translator/program_translator.cc
paddle/fluid/translator/program_translator.cc
+1
-1
paddle/ir/core/block.h
paddle/ir/core/block.h
+6
-0
paddle/ir/core/builtin_attribute.cc
paddle/ir/core/builtin_attribute.cc
+2
-0
paddle/ir/core/builtin_attribute.h
paddle/ir/core/builtin_attribute.h
+10
-1
paddle/ir/core/builtin_attribute_storage.h
paddle/ir/core/builtin_attribute_storage.h
+1
-0
paddle/ir/core/builtin_dialect.cc
paddle/ir/core/builtin_dialect.cc
+3
-1
paddle/ir/core/builtin_op.cc
paddle/ir/core/builtin_op.cc
+54
-0
paddle/ir/core/builtin_op.h
paddle/ir/core/builtin_op.h
+25
-0
paddle/ir/core/op_base.h
paddle/ir/core/op_base.h
+6
-4
paddle/ir/core/operation.cc
paddle/ir/core/operation.cc
+24
-6
paddle/ir/core/operation.h
paddle/ir/core/operation.h
+13
-10
paddle/ir/core/operation_utils.h
paddle/ir/core/operation_utils.h
+10
-5
paddle/ir/core/program.cc
paddle/ir/core/program.cc
+10
-4
paddle/ir/core/program.h
paddle/ir/core/program.h
+23
-13
paddle/ir/core/region.cc
paddle/ir/core/region.cc
+3
-0
paddle/ir/core/region.h
paddle/ir/core/region.h
+3
-0
test/cpp/ir/core/ir_op_test.cc
test/cpp/ir/core/ir_op_test.cc
+28
-3
test/cpp/ir/core/ir_program_test.cc
test/cpp/ir/core/ir_program_test.cc
+22
-13
未找到文件。
paddle/fluid/translator/op_translator.cc
浏览文件 @
49bedfd3
...
@@ -113,7 +113,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
...
@@ -113,7 +113,7 @@ inline ir::Operation* InsertSliceOperationForTarget(
op_attribute_map
,
op_attribute_map
,
{
src_vec_type
[
defining_info
.
idx_in_vector
]},
{
src_vec_type
[
defining_info
.
idx_in_vector
]},
op_info
);
op_info
);
program
->
InsertOp
(
operation
);
program
->
block
()
->
push_back
(
operation
);
ir
::
OpResult
target_op_result
=
operation
->
GetResultByIndex
(
0
);
ir
::
OpResult
target_op_result
=
operation
->
GetResultByIndex
(
0
);
(
*
param_map
)[
arg_name
]
=
VariableDefiningInfo
(
target_op_result
);
(
*
param_map
)[
arg_name
]
=
VariableDefiningInfo
(
target_op_result
);
return
operation
;
return
operation
;
...
@@ -137,7 +137,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
...
@@ -137,7 +137,7 @@ inline ir::Operation* InsertCombineOperationForTarget(
ir
::
Type
target_vec_type
=
ir
::
VectorType
::
get
(
ctx
,
types_in_vec
);
ir
::
Type
target_vec_type
=
ir
::
VectorType
::
get
(
ctx
,
types_in_vec
);
ir
::
Operation
*
operation
=
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
src_values
,
{},
{
target_vec_type
},
op_info
);
ir
::
Operation
::
create
(
src_values
,
{},
{
target_vec_type
},
op_info
);
program
->
InsertOp
(
operation
);
program
->
block
()
->
push_back
(
operation
);
return
operation
;
return
operation
;
}
}
...
@@ -282,7 +282,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
...
@@ -282,7 +282,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
ir
::
Operation
*
operation
=
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
op_inputs
,
{},
op_output_types
,
op_info
);
ir
::
Operation
::
create
(
op_inputs
,
{},
op_output_types
,
op_info
);
program
->
InsertOp
(
operation
);
program
->
block
()
->
push_back
(
operation
);
RecordOpResultMapping
(
param_map
,
op_desc
,
operation
,
arg_to_idx
);
RecordOpResultMapping
(
param_map
,
op_desc
,
operation
,
arg_to_idx
);
return
operation
;
return
operation
;
...
@@ -300,7 +300,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
...
@@ -300,7 +300,7 @@ ir::Operation* FeedOpHandler(ir::IrContext* ctx,
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
ir
::
Operation
*
operation
=
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
op_inputs
,
{},
op_output_types
,
op_info
);
ir
::
Operation
::
create
(
op_inputs
,
{},
op_output_types
,
op_info
);
program
->
InsertOp
(
operation
);
program
->
block
()
->
push_back
(
operation
);
RecordOpResultMapping
(
param_map
,
op_desc
,
operation
,
arg_to_idx
);
RecordOpResultMapping
(
param_map
,
op_desc
,
operation
,
arg_to_idx
);
return
operation
;
return
operation
;
...
@@ -316,7 +316,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
...
@@ -316,7 +316,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
auto
op_info
=
LoopkUpOpInfo
(
ctx
,
op_desc
);
ir
::
Operation
*
operation
=
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
op_inputs
,
{},
op_output_types
,
op_info
);
ir
::
Operation
::
create
(
op_inputs
,
{},
op_output_types
,
op_info
);
program
->
InsertOp
(
operation
);
program
->
block
()
->
push_back
(
operation
);
return
operation
;
return
operation
;
}
}
...
...
paddle/fluid/translator/program_translator.cc
浏览文件 @
49bedfd3
...
@@ -80,7 +80,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
...
@@ -80,7 +80,7 @@ void ProgramTranslator::ExtractParameterFromSingleBlock(
ir
::
Type
translated_var_type
=
type_translator
[
var
->
GetType
()](
ctx
,
*
var
);
ir
::
Type
translated_var_type
=
type_translator
[
var
->
GetType
()](
ctx
,
*
var
);
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
ir
::
Operation
*
operation
=
ir
::
Operation
::
create
(
{},
op_attribute_map
,
{
translated_var_type
},
op_info
);
{},
op_attribute_map
,
{
translated_var_type
},
op_info
);
program
->
InsertOp
(
operation
);
program
->
block
()
->
push_back
(
operation
);
param_map
[
var
->
Name
()]
=
param_map
[
var
->
Name
()]
=
VariableDefiningInfo
(
operation
->
GetResultByIndex
(
0
));
VariableDefiningInfo
(
operation
->
GetResultByIndex
(
0
));
VLOG
(
10
)
<<
"[op translated][get parameter]"
<<
operation
;
VLOG
(
10
)
<<
"[op translated][get parameter]"
<<
operation
;
...
...
paddle/ir/core/block.h
浏览文件 @
49bedfd3
...
@@ -46,6 +46,12 @@ class Block {
...
@@ -46,6 +46,12 @@ class Block {
iterator
insert
(
const_iterator
iterator
,
Operation
*
op
);
iterator
insert
(
const_iterator
iterator
,
Operation
*
op
);
void
clear
();
void
clear
();
Region
*
GetParentRegion
()
const
{
return
parent_
;
}
Operation
*
GetParentOp
()
const
{
return
parent_
?
parent_
->
GetParentOp
()
:
nullptr
;
}
private:
private:
Block
(
Block
&
)
=
delete
;
Block
(
Block
&
)
=
delete
;
Block
&
operator
=
(
const
Block
&
)
=
delete
;
Block
&
operator
=
(
const
Block
&
)
=
delete
;
...
...
paddle/ir/core/builtin_attribute.cc
浏览文件 @
49bedfd3
...
@@ -33,4 +33,6 @@ std::vector<Attribute> ArrayAttribute::data() const {
...
@@ -33,4 +33,6 @@ std::vector<Attribute> ArrayAttribute::data() const {
return
storage
()
->
GetAsKey
();
return
storage
()
->
GetAsKey
();
}
}
void
*
PointerAttribute
::
data
()
const
{
return
storage
()
->
GetAsKey
();
}
}
// namespace ir
}
// namespace ir
paddle/ir/core/builtin_attribute.h
浏览文件 @
49bedfd3
...
@@ -25,7 +25,7 @@ class StrAttribute : public Attribute {
...
@@ -25,7 +25,7 @@ class StrAttribute : public Attribute {
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR
(
StrAttribute
,
StrAttributeStorage
);
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR
(
StrAttribute
,
StrAttributeStorage
);
bool
operator
<
(
const
StrAttribute
&
right
)
const
{
bool
operator
<
(
const
StrAttribute
&
right
)
const
{
return
storage
()
<
right
.
storage
();
return
storage
()
<
right
.
storage
();
}
}
...
@@ -94,4 +94,13 @@ class ArrayAttribute : public Attribute {
...
@@ -94,4 +94,13 @@ class ArrayAttribute : public Attribute {
Attribute
operator
[](
size_t
index
)
const
{
return
data
()[
index
];
}
Attribute
operator
[](
size_t
index
)
const
{
return
data
()[
index
];
}
};
};
class
PointerAttribute
:
public
Attribute
{
public:
using
Attribute
::
Attribute
;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR
(
PointerAttribute
,
PointerAttributeStorage
);
void
*
data
()
const
;
};
}
// namespace ir
}
// namespace ir
paddle/ir/core/builtin_attribute_storage.h
浏览文件 @
49bedfd3
...
@@ -83,6 +83,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
...
@@ -83,6 +83,7 @@ DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE
(
DoubleAttributeStorage
,
double
);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE
(
DoubleAttributeStorage
,
double
);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE
(
Int32_tAttributeStorage
,
int32_t
);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE
(
Int32_tAttributeStorage
,
int32_t
);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE
(
Int64_tAttributeStorage
,
int64_t
);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE
(
Int64_tAttributeStorage
,
int64_t
);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE
(
PointerAttributeStorage
,
void
*
);
struct
ArrayAttributeStorage
:
public
AttributeStorage
{
struct
ArrayAttributeStorage
:
public
AttributeStorage
{
using
ParamKey
=
std
::
vector
<
Attribute
>
;
using
ParamKey
=
std
::
vector
<
Attribute
>
;
...
...
paddle/ir/core/builtin_dialect.cc
浏览文件 @
49bedfd3
...
@@ -40,11 +40,13 @@ void BuiltinDialect::initialize() {
...
@@ -40,11 +40,13 @@ void BuiltinDialect::initialize() {
ir
::
BoolAttribute
,
ir
::
BoolAttribute
,
ir
::
FloatAttribute
,
ir
::
FloatAttribute
,
ir
::
DoubleAttribute
,
ir
::
DoubleAttribute
,
ir
::
PointerAttribute
,
ir
::
Int32_tAttribute
,
ir
::
Int32_tAttribute
,
ir
::
Int64_tAttribute
,
ir
::
Int64_tAttribute
,
ir
::
ArrayAttribute
>
();
ir
::
ArrayAttribute
>
();
RegisterOps
<
ir
::
GetParameterOp
,
RegisterOps
<
ir
::
ModuleOp
,
ir
::
GetParameterOp
,
ir
::
SetParameterOp
,
ir
::
SetParameterOp
,
ir
::
CombineOp
,
ir
::
CombineOp
,
ir
::
SliceOp
>
();
ir
::
SliceOp
>
();
...
...
paddle/ir/core/builtin_op.cc
浏览文件 @
49bedfd3
...
@@ -19,6 +19,60 @@
...
@@ -19,6 +19,60 @@
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/enforce.h"
namespace
ir
{
namespace
ir
{
const
char
*
ModuleOp
::
attributes_name
[
attributes_num
]
=
{
"program"
};
Program
*
ModuleOp
::
program
()
{
const
AttributeMap
&
attr
=
operation
()
->
attributes
();
auto
iter
=
attr
.
find
(
"program"
);
if
(
iter
==
attr
.
end
()
||
!
iter
->
second
)
return
nullptr
;
return
static_cast
<
Program
*>
(
iter
->
second
.
dyn_cast
<
PointerAttribute
>
().
data
());
}
Block
*
ModuleOp
::
block
()
{
assert
(
operation
()
!=
nullptr
);
assert
(
operation
()
->
num_regions
()
==
1
);
assert
(
operation
()
->
GetRegion
(
0
).
size
()
==
1
);
return
operation
()
->
GetRegion
(
0
).
front
();
}
ModuleOp
ModuleOp
::
create
(
IrContext
*
context
,
Program
*
pointer
)
{
ir
::
OpInfo
info
=
context
->
GetRegisteredOpInfo
(
name
());
OperationArgument
argument
(
info
);
argument
.
AddRegion
()
->
emplace_back
();
argument
.
addAttribute
(
"program"
,
PointerAttribute
::
get
(
context
,
pointer
));
return
ModuleOp
(
Operation
::
create
(
std
::
move
(
argument
)));
}
void
ModuleOp
::
destroy
()
{
if
(
operation
())
{
operation
()
->
destroy
();
*
this
=
ModuleOp
(
nullptr
);
}
}
void
ModuleOp
::
verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
VLOG
(
4
)
<<
"Verifying inputs, outputs and attributes for: ModuleOp."
;
// Verify inputs type:
if
(
inputs
.
size
()
!=
0
)
{
throw
(
"The size of inputs must be equal to 0."
);
}
// Verify if attributes contain attribute name in attributes_name:
auto
iter
=
attributes
.
find
(
"program"
);
if
(
iter
==
attributes
.
end
()
||
!
iter
->
second
.
isa
<
PointerAttribute
>
())
{
throw
(
"Type of attribute: program is not right."
);
}
// Verify outputs type:
if
(
outputs
.
size
()
!=
0
)
{
throw
(
"The size of outputs must be equal to 0."
);
}
}
const
char
*
GetParameterOp
::
attributes_name
[
attributes_num
]
=
{
const
char
*
GetParameterOp
::
attributes_name
[
attributes_num
]
=
{
"parameter_name"
};
"parameter_name"
};
...
...
paddle/ir/core/builtin_op.h
浏览文件 @
49bedfd3
...
@@ -18,6 +18,31 @@
...
@@ -18,6 +18,31 @@
namespace
ir
{
namespace
ir
{
class
Program
;
class
Block
;
///
/// \brief ModuleOp
///
class
ModuleOp
:
public
ir
::
Op
<
ModuleOp
>
{
public:
using
Op
::
Op
;
static
const
char
*
name
()
{
return
"builtin.module"
;
}
static
constexpr
uint32_t
attributes_num
=
1
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
);
Program
*
program
();
Block
*
block
();
//
// As the top operation, ModuleOp only support create&destroye through
// below interface: "create"&"destroy".
static
ModuleOp
create
(
IrContext
*
context
,
Program
*
pointer
);
void
destroy
();
};
///
///
/// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute,
/// \brief GetParameterOp: OpResult = GetParameterOp({StrAttribute,
/// StrAttribute})
/// StrAttribute})
...
...
paddle/ir/core/op_base.h
浏览文件 @
49bedfd3
...
@@ -66,7 +66,7 @@ class InterfaceValue {
...
@@ -66,7 +66,7 @@ class InterfaceValue {
class
OpBase
{
class
OpBase
{
public:
public:
explicit
OpBase
(
Operation
*
operation
)
:
operation_
(
operation
)
{}
explicit
OpBase
(
Operation
*
operation
=
nullptr
)
:
operation_
(
operation
)
{}
Operation
*
operation
()
const
{
return
operation_
;
}
Operation
*
operation
()
const
{
return
operation_
;
}
...
@@ -76,6 +76,8 @@ class OpBase {
...
@@ -76,6 +76,8 @@ class OpBase {
Operation
*
operator
->
()
const
{
return
operation_
;
}
Operation
*
operator
->
()
const
{
return
operation_
;
}
IrContext
*
ir_context
()
const
{
return
operation_
->
ir_context
();
}
private:
private:
Operation
*
operation_
;
// Not owned
Operation
*
operation_
;
// Not owned
};
};
...
@@ -91,7 +93,7 @@ class OpTraitBase : public OpBase {
...
@@ -91,7 +93,7 @@ class OpTraitBase : public OpBase {
static
TypeId
GetTraitId
()
{
return
TypeId
::
get
<
ConcreteTrait
>
();
}
static
TypeId
GetTraitId
()
{
return
TypeId
::
get
<
ConcreteTrait
>
();
}
static
ConcreteTrait
dyn_cast
(
Operation
*
op
)
{
static
ConcreteTrait
dyn_cast
(
Operation
*
op
)
{
if
(
op
->
HasTrait
<
ConcreteTrait
>
())
{
if
(
op
&&
op
->
HasTrait
<
ConcreteTrait
>
())
{
return
ConcreteTrait
(
op
);
return
ConcreteTrait
(
op
);
}
}
return
ConcreteTrait
(
nullptr
);
return
ConcreteTrait
(
nullptr
);
...
@@ -109,7 +111,7 @@ class OpInterfaceBase : public OpBase {
...
@@ -109,7 +111,7 @@ class OpInterfaceBase : public OpBase {
static
TypeId
GetInterfaceId
()
{
return
TypeId
::
get
<
ConcreteInterface
>
();
}
static
TypeId
GetInterfaceId
()
{
return
TypeId
::
get
<
ConcreteInterface
>
();
}
static
ConcreteInterface
dyn_cast
(
Operation
*
op
)
{
static
ConcreteInterface
dyn_cast
(
Operation
*
op
)
{
if
(
op
->
HasInterface
<
ConcreteInterface
>
())
{
if
(
op
&&
op
->
HasInterface
<
ConcreteInterface
>
())
{
return
ConcreteInterface
(
return
ConcreteInterface
(
op
,
op
->
op_info
().
GetInterfaceImpl
<
ConcreteInterface
>
());
op
,
op
->
op_info
().
GetInterfaceImpl
<
ConcreteInterface
>
());
}
}
...
@@ -182,7 +184,7 @@ class Op : public OpBase {
...
@@ -182,7 +184,7 @@ class Op : public OpBase {
typename
Filter
<
OpInterfaceBase
,
std
::
tuple
<
TraitOrInterface
...
>>::
Type
;
typename
Filter
<
OpInterfaceBase
,
std
::
tuple
<
TraitOrInterface
...
>>::
Type
;
static
ConcreteOp
dyn_cast
(
Operation
*
op
)
{
static
ConcreteOp
dyn_cast
(
Operation
*
op
)
{
if
(
op
->
op_info
().
id
()
==
TypeId
::
get
<
ConcreteOp
>
())
{
if
(
op
&&
op
->
op_info
().
id
()
==
TypeId
::
get
<
ConcreteOp
>
())
{
return
ConcreteOp
(
op
);
return
ConcreteOp
(
op
);
}
}
return
ConcreteOp
(
nullptr
);
return
ConcreteOp
(
nullptr
);
...
...
paddle/ir/core/operation.cc
浏览文件 @
49bedfd3
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/region.h"
...
@@ -21,7 +22,7 @@
...
@@ -21,7 +22,7 @@
namespace
ir
{
namespace
ir
{
Operation
*
Operation
::
create
(
OperationArgument
&&
argument
)
{
Operation
*
Operation
::
create
(
OperationArgument
&&
argument
)
{
Operation
*
op
=
create
(
argument
.
inputs
,
Operation
*
op
=
create
(
argument
.
inputs
,
argument
.
attribute
,
argument
.
attribute
s
,
argument
.
output_types
,
argument
.
output_types
,
argument
.
info
,
argument
.
info
,
argument
.
regions
.
size
());
argument
.
regions
.
size
());
...
@@ -36,13 +37,13 @@ Operation *Operation::create(OperationArgument &&argument) {
...
@@ -36,13 +37,13 @@ Operation *Operation::create(OperationArgument &&argument) {
// and operators, and construct it in the order of: OpOutlineResult,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
// OpInlineResult, Operation, Operand.
Operation
*
Operation
::
create
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
Operation
*
Operation
::
create
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
AttributeMap
&
attribute
,
const
AttributeMap
&
attribute
s
,
const
std
::
vector
<
ir
::
Type
>
&
output_types
,
const
std
::
vector
<
ir
::
Type
>
&
output_types
,
ir
::
OpInfo
op_info
,
ir
::
OpInfo
op_info
,
size_t
num_regions
)
{
size_t
num_regions
)
{
// 0. Verify
// 0. Verify
if
(
op_info
)
{
if
(
op_info
)
{
op_info
.
verify
(
inputs
,
output_types
,
attribute
);
op_info
.
verify
(
inputs
,
output_types
,
attribute
s
);
}
}
// 1. Calculate the required memory size for OpResults + Operation +
// 1. Calculate the required memory size for OpResults + Operation +
// OpOperands.
// OpOperands.
...
@@ -76,7 +77,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
...
@@ -76,7 +77,7 @@ Operation *Operation::create(const std::vector<ir::OpResult> &inputs,
}
}
// 3.2. Construct Operation.
// 3.2. Construct Operation.
Operation
*
op
=
new
(
base_ptr
)
Operation
*
op
=
new
(
base_ptr
)
Operation
(
attribute
,
op_info
,
num_results
,
num_operands
,
num_regions
);
Operation
(
attribute
s
,
op_info
,
num_results
,
num_operands
,
num_regions
);
base_ptr
+=
sizeof
(
Operation
);
base_ptr
+=
sizeof
(
Operation
);
// 3.3. Construct OpOperands.
// 3.3. Construct OpOperands.
if
((
reinterpret_cast
<
uintptr_t
>
(
base_ptr
)
&
0x7
)
!=
0
)
{
if
((
reinterpret_cast
<
uintptr_t
>
(
base_ptr
)
&
0x7
)
!=
0
)
{
...
@@ -160,12 +161,12 @@ void Operation::destroy() {
...
@@ -160,12 +161,12 @@ void Operation::destroy() {
IrContext
*
Operation
::
ir_context
()
const
{
return
op_info_
.
ir_context
();
}
IrContext
*
Operation
::
ir_context
()
const
{
return
op_info_
.
ir_context
();
}
Operation
::
Operation
(
const
AttributeMap
&
attribute
,
Operation
::
Operation
(
const
AttributeMap
&
attribute
s
,
ir
::
OpInfo
op_info
,
ir
::
OpInfo
op_info
,
uint32_t
num_results
,
uint32_t
num_results
,
uint32_t
num_operands
,
uint32_t
num_operands
,
uint32_t
num_regions
)
uint32_t
num_regions
)
:
attribute
_
(
attribute
),
:
attribute
s_
(
attributes
),
op_info_
(
op_info
),
op_info_
(
op_info
),
num_results_
(
num_results
),
num_results_
(
num_results
),
num_operands_
(
num_operands
),
num_operands_
(
num_operands
),
...
@@ -223,6 +224,23 @@ std::string Operation::print() {
...
@@ -223,6 +224,23 @@ std::string Operation::print() {
std
::
string
Operation
::
op_name
()
const
{
return
op_info_
.
name
();
}
std
::
string
Operation
::
op_name
()
const
{
return
op_info_
.
name
();
}
Region
*
Operation
::
GetParentRegion
()
const
{
return
parent_
?
parent_
->
GetParentRegion
()
:
nullptr
;
}
Operation
*
Operation
::
GetParentOp
()
const
{
return
parent_
?
parent_
->
GetParentOp
()
:
nullptr
;
}
Program
*
Operation
::
GetParentProgram
()
{
Operation
*
op
=
this
;
while
(
Operation
*
parent_op
=
op
->
GetParentOp
())
{
op
=
parent_op
;
}
ModuleOp
module_op
=
op
->
dyn_cast
<
ModuleOp
>
();
return
module_op
?
module_op
.
program
()
:
nullptr
;
}
Region
&
Operation
::
GetRegion
(
unsigned
index
)
{
Region
&
Operation
::
GetRegion
(
unsigned
index
)
{
assert
(
index
<
num_regions_
&&
"invalid region index"
);
assert
(
index
<
num_regions_
&&
"invalid region index"
);
return
regions_
[
index
];
return
regions_
[
index
];
...
...
paddle/ir/core/operation.h
浏览文件 @
49bedfd3
...
@@ -34,7 +34,7 @@ class alignas(8) Operation final {
...
@@ -34,7 +34,7 @@ class alignas(8) Operation final {
/// used in conjunction.
/// used in conjunction.
///
///
static
Operation
*
create
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
static
Operation
*
create
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
AttributeMap
&
attribute
,
const
AttributeMap
&
attribute
s
,
const
std
::
vector
<
ir
::
Type
>
&
output_types
,
const
std
::
vector
<
ir
::
Type
>
&
output_types
,
ir
::
OpInfo
op_info
,
ir
::
OpInfo
op_info
,
size_t
num_regions
=
0
);
size_t
num_regions
=
0
);
...
@@ -45,8 +45,6 @@ class alignas(8) Operation final {
...
@@ -45,8 +45,6 @@ class alignas(8) Operation final {
///
///
void
destroy
();
void
destroy
();
Block
*
parent
()
const
{
return
parent_
;
}
IrContext
*
ir_context
()
const
;
IrContext
*
ir_context
()
const
;
ir
::
OpResult
GetResultByIndex
(
uint32_t
index
)
const
;
ir
::
OpResult
GetResultByIndex
(
uint32_t
index
)
const
;
...
@@ -55,7 +53,11 @@ class alignas(8) Operation final {
...
@@ -55,7 +53,11 @@ class alignas(8) Operation final {
std
::
string
print
();
std
::
string
print
();
const
AttributeMap
&
attribute
()
const
{
return
attribute_
;
}
const
AttributeMap
&
attributes
()
const
{
return
attributes_
;
}
void
SetAttribute
(
const
std
::
string
&
key
,
Attribute
value
)
{
attributes_
[
key
]
=
value
;
}
ir
::
OpInfo
op_info
()
const
{
return
op_info_
;
}
ir
::
OpInfo
op_info
()
const
{
return
op_info_
;
}
...
@@ -82,11 +84,13 @@ class alignas(8) Operation final {
...
@@ -82,11 +84,13 @@ class alignas(8) Operation final {
return
op_info_
.
HasInterface
<
Interface
>
();
return
op_info_
.
HasInterface
<
Interface
>
();
}
}
Program
*
parent_program
()
const
{
return
parent_program
_
;
}
Block
*
GetParentBlock
()
const
{
return
parent
_
;
}
void
set_parent_program
(
Program
*
parent_program
)
{
Region
*
GetParentRegion
()
const
;
parent_program_
=
parent_program
;
}
Operation
*
GetParentOp
()
const
;
Program
*
GetParentProgram
();
/// Returns the region held by this operation at position 'index'.
/// Returns the region held by this operation at position 'index'.
Region
&
GetRegion
(
unsigned
index
);
Region
&
GetRegion
(
unsigned
index
);
...
@@ -115,7 +119,7 @@ class alignas(8) Operation final {
...
@@ -115,7 +119,7 @@ class alignas(8) Operation final {
static
T
call
(
Operation
*
op
)
{
return
T
::
dyn_cast
(
op
);
}
static
T
call
(
Operation
*
op
)
{
return
T
::
dyn_cast
(
op
);
}
};
};
AttributeMap
attribute_
;
AttributeMap
attribute
s
_
;
OpInfo
op_info_
;
OpInfo
op_info_
;
...
@@ -124,7 +128,6 @@ class alignas(8) Operation final {
...
@@ -124,7 +128,6 @@ class alignas(8) Operation final {
const
uint32_t
num_regions_
=
0
;
const
uint32_t
num_regions_
=
0
;
Region
*
regions_
{
nullptr
};
Region
*
regions_
{
nullptr
};
Program
*
parent_program_
{
nullptr
};
Block
*
parent_
{
nullptr
};
Block
*
parent_
{
nullptr
};
};
};
...
...
paddle/ir/core/operation_utils.h
浏览文件 @
49bedfd3
...
@@ -32,7 +32,7 @@ using AttributeMap = std::unordered_map<std::string, Attribute>;
...
@@ -32,7 +32,7 @@ using AttributeMap = std::unordered_map<std::string, Attribute>;
// with the builder APIs.
// with the builder APIs.
struct
OperationArgument
{
struct
OperationArgument
{
std
::
vector
<
OpResult
>
inputs
;
std
::
vector
<
OpResult
>
inputs
;
AttributeMap
attribute
;
AttributeMap
attribute
s
;
std
::
vector
<
Type
>
output_types
;
std
::
vector
<
Type
>
output_types
;
OpInfo
info
;
OpInfo
info
;
std
::
vector
<
std
::
unique_ptr
<
Region
>>
regions
;
std
::
vector
<
std
::
unique_ptr
<
Region
>>
regions
;
...
@@ -41,12 +41,12 @@ struct OperationArgument {
...
@@ -41,12 +41,12 @@ struct OperationArgument {
OperationArgument
(
IrContext
*
ir_context
,
const
std
::
string
&
name
);
OperationArgument
(
IrContext
*
ir_context
,
const
std
::
string
&
name
);
explicit
OperationArgument
(
OpInfo
info
)
:
info
(
info
)
{}
explicit
OperationArgument
(
OpInfo
info
)
:
info
(
info
)
{}
OperationArgument
(
const
std
::
vector
<
OpResult
>&
operands
,
OperationArgument
(
const
std
::
vector
<
OpResult
>&
operands
,
const
AttributeMap
&
named_attr
,
const
AttributeMap
&
attributes
,
const
std
::
vector
<
Type
>&
types
,
const
std
::
vector
<
Type
>&
types
,
OpInfo
info
,
OpInfo
info
,
std
::
vector
<
std
::
unique_ptr
<
Region
>>&&
regions
=
{})
std
::
vector
<
std
::
unique_ptr
<
Region
>>&&
regions
=
{})
:
inputs
(
operands
),
:
inputs
(
operands
),
attribute
(
named_attr
),
attribute
s
(
attributes
),
output_types
(
types
),
output_types
(
types
),
info
(
info
),
info
(
info
),
regions
(
std
::
move
(
regions
))
{}
regions
(
std
::
move
(
regions
))
{}
...
@@ -59,13 +59,18 @@ struct OperationArgument {
...
@@ -59,13 +59,18 @@ struct OperationArgument {
/// Add an attribute with the specified name.
/// Add an attribute with the specified name.
void
addAttribute
(
const
std
::
string
&
name
,
Attribute
attr
)
{
void
addAttribute
(
const
std
::
string
&
name
,
Attribute
attr
)
{
this
->
attribute
[
name
]
=
attr
;
attributes
[
name
]
=
attr
;
}
}
/// Add an array of named attributes.
/// Add an array of named attributes.
template
<
class
InputIt
>
template
<
class
InputIt
>
void
addAttributes
(
InputIt
first
,
InputIt
last
);
void
addAttributes
(
InputIt
first
,
InputIt
last
);
/// Get the context held by this operation state.
/// Get the context held by this operation state.
IrContext
*
getContext
()
const
{
return
info
.
ir_context
();
}
IrContext
*
getContext
()
const
{
return
info
.
ir_context
();
}
Region
*
AddRegion
()
{
regions
.
emplace_back
(
new
Region
);
return
regions
.
back
().
get
();
}
};
};
template
<
class
InputIt
>
template
<
class
InputIt
>
...
@@ -83,7 +88,7 @@ void OperationArgument::addTypes(InputIt first, InputIt last) {
...
@@ -83,7 +88,7 @@ void OperationArgument::addTypes(InputIt first, InputIt last) {
template
<
class
InputIt
>
template
<
class
InputIt
>
void
OperationArgument
::
addAttributes
(
InputIt
first
,
InputIt
last
)
{
void
OperationArgument
::
addAttributes
(
InputIt
first
,
InputIt
last
)
{
while
(
first
!=
last
)
{
while
(
first
!=
last
)
{
attribute
[
first
->
first
]
=
first
->
second
;
attribute
s
[
first
->
first
]
=
first
->
second
;
++
first
;
++
first
;
}
}
}
}
...
...
paddle/ir/core/program.cc
浏览文件 @
49bedfd3
...
@@ -16,11 +16,17 @@
...
@@ -16,11 +16,17 @@
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/ir_context.h"
namespace
ir
{
namespace
ir
{
Program
::~
Program
()
=
default
;
void
Program
::
InsertOp
(
Operation
*
op
)
{
Program
::
Program
(
IrContext
*
context
)
{
block_
.
push_back
(
op
);
module_
=
ModuleOp
::
create
(
context
,
this
);
op
->
set_parent_program
(
this
);
}
Program
::
Program
()
:
Program
(
IrContext
::
Instance
())
{}
Program
::~
Program
()
{
if
(
module_
)
{
module_
.
destroy
();
}
}
}
Parameter
*
Program
::
GetParameter
(
std
::
string
name
)
const
{
Parameter
*
Program
::
GetParameter
(
std
::
string
name
)
const
{
...
...
paddle/ir/core/program.h
浏览文件 @
49bedfd3
...
@@ -19,10 +19,13 @@
...
@@ -19,10 +19,13 @@
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/parameter.h"
namespace
ir
{
namespace
ir
{
class
IrContext
;
///
///
/// \brief Program is an abstraction of model structure, divided into
/// \brief Program is an abstraction of model structure, divided into
/// computational graphs and weights. At the current stage, a computational
/// computational graphs and weights. At the current stage, a computational
...
@@ -33,27 +36,34 @@ namespace ir {
...
@@ -33,27 +36,34 @@ namespace ir {
///
///
class
Program
{
class
Program
{
public:
public:
using
ParameterMap
=
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Parameter
>>
;
explicit
Program
(
IrContext
*
context
);
Program
();
Program
(
Program
&&
)
=
delete
;
Program
(
const
Program
&
program
)
=
delete
;
Program
&
operator
=
(
const
Program
&
)
=
delete
;
Program
&
operator
=
(
Program
&&
);
~
Program
();
~
Program
();
Block
*
block
()
{
return
&
block_
;
}
size_t
parameters_num
()
const
{
return
parameters_
.
size
();
}
size_t
parameters_num
()
const
{
return
parameters_
.
size
();
}
///
ModuleOp
module_op
()
{
return
module_
;
}
/// \brief Insert the Operation* constructed by Operation::create(...) into
/// this Program. NOTE: At this time, the memory management permission of
/// Operation* will be owned by this Program. The user does not need to call
/// Operation::destroy() manually
///
void
InsertOp
(
Operation
*
op
);
Parameter
*
GetParameter
(
std
::
string
name
)
const
;
Block
*
block
()
{
return
module_
.
block
();
}
Parameter
*
GetParameter
(
std
::
string
name
)
const
;
void
SetParameter
(
std
::
string
name
,
std
::
unique_ptr
<
Parameter
>&&
parameter
);
void
SetParameter
(
std
::
string
name
,
std
::
unique_ptr
<
Parameter
>&&
parameter
);
ParameterMap
&
parameters
()
{
return
parameters_
;
}
void
set_parameters
(
ParameterMap
&&
parameters
)
{
parameters_
=
std
::
move
(
parameters
);
}
private:
private:
Block
block_
;
// computation graph
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
Parameter
>>
parameters_
;
ModuleOp
module_
;
// weight
ParameterMap
parameters_
;
};
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Program
&
program
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
Program
&
program
);
...
...
paddle/ir/core/region.cc
浏览文件 @
49bedfd3
...
@@ -22,6 +22,9 @@ void Region::push_back(Block *block) {
...
@@ -22,6 +22,9 @@ void Region::push_back(Block *block) {
block
->
set_parent
(
this
);
block
->
set_parent
(
this
);
blocks_
.
push_back
(
block
);
blocks_
.
push_back
(
block
);
}
}
void
Region
::
emplace_back
()
{
push_back
(
new
Block
);
}
void
Region
::
push_front
(
Block
*
block
)
{
void
Region
::
push_front
(
Block
*
block
)
{
block
->
set_parent
(
this
);
block
->
set_parent
(
this
);
blocks_
.
push_front
(
block
);
blocks_
.
push_front
(
block
);
...
...
paddle/ir/core/region.h
浏览文件 @
49bedfd3
...
@@ -41,12 +41,15 @@ class Region {
...
@@ -41,12 +41,15 @@ class Region {
Block
*
back
()
const
{
return
blocks_
.
back
();
}
Block
*
back
()
const
{
return
blocks_
.
back
();
}
Block
*
front
()
const
{
return
blocks_
.
front
();
}
Block
*
front
()
const
{
return
blocks_
.
front
();
}
void
push_back
(
Block
*
block
);
void
push_back
(
Block
*
block
);
void
emplace_back
();
void
push_front
(
Block
*
block
);
void
push_front
(
Block
*
block
);
iterator
insert
(
const_iterator
position
,
Block
*
block
);
iterator
insert
(
const_iterator
position
,
Block
*
block
);
void
clear
();
void
clear
();
void
TakeBody
(
Region
&&
other
);
void
TakeBody
(
Region
&&
other
);
Operation
*
GetParentOp
()
const
{
return
parent_
;
}
private:
private:
Region
(
Region
&
)
=
delete
;
Region
(
Region
&
)
=
delete
;
Region
&
operator
=
(
const
Region
&
)
=
delete
;
Region
&
operator
=
(
const
Region
&
)
=
delete
;
...
...
test/cpp/ir/core/ir_op_test.cc
浏览文件 @
49bedfd3
...
@@ -17,10 +17,12 @@
...
@@ -17,10 +17,12 @@
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/region.h"
/// \brief Define built-in Trait, derived from OpTraitBase.
/// \brief Define built-in Trait, derived from OpTraitBase.
...
@@ -133,7 +135,7 @@ class Operation2
...
@@ -133,7 +135,7 @@ class Operation2
throw
(
"Type of attribute: parameter_name is not right."
);
throw
(
"Type of attribute: parameter_name is not right."
);
}
}
}
}
static
void
InferShape
()
{
VLOG
(
0
)
<<
"This is op2's InferShape interface."
;
}
static
void
InferShape
()
{
VLOG
(
2
)
<<
"This is op2's InferShape interface."
;
}
};
};
const
char
*
Operation2
::
attributes_name
[
attributes_num
]
=
{
"op2_attr1"
,
const
char
*
Operation2
::
attributes_name
[
attributes_num
]
=
{
"op2_attr1"
,
"op2_attr2"
};
"op2_attr2"
};
...
@@ -212,8 +214,8 @@ TEST(op_test, region_test) {
...
@@ -212,8 +214,8 @@ TEST(op_test, region_test) {
op1_info
);
op1_info
);
ir
::
OperationArgument
argument
(
op2_info
);
ir
::
OperationArgument
argument
(
op2_info
);
argument
.
attribute
=
CreateAttributeMap
({
"op2_attr1"
,
"op2_attr2"
},
argument
.
attribute
s
=
CreateAttributeMap
({
"op2_attr1"
,
"op2_attr2"
},
{
"op2_attr1"
,
"op2_attr2"
});
{
"op2_attr1"
,
"op2_attr2"
});
argument
.
output_types
=
{
ir
::
Float32Type
::
get
(
ctx
)};
argument
.
output_types
=
{
ir
::
Float32Type
::
get
(
ctx
)};
argument
.
regions
.
emplace_back
(
std
::
make_unique
<
ir
::
Region
>
());
argument
.
regions
.
emplace_back
(
std
::
make_unique
<
ir
::
Region
>
());
ir
::
Region
*
region
=
argument
.
regions
.
back
().
get
();
ir
::
Region
*
region
=
argument
.
regions
.
back
().
get
();
...
@@ -228,3 +230,26 @@ TEST(op_test, region_test) {
...
@@ -228,3 +230,26 @@ TEST(op_test, region_test) {
ir
::
Operation
*
op2
=
ir
::
Operation
::
create
(
std
::
move
(
argument
));
ir
::
Operation
*
op2
=
ir
::
Operation
::
create
(
std
::
move
(
argument
));
op2
->
destroy
();
op2
->
destroy
();
}
}
TEST
(
op_test
,
module_op_death
)
{
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ir
::
OpInfo
op_info
=
ctx
->
GetRegisteredOpInfo
(
ir
::
ModuleOp
::
name
());
// (3) Test uses for op.
std
::
vector
<
ir
::
OpResult
>
inputs
{
ir
::
OpResult
()};
ir
::
AttributeMap
attrs
{{
"program"
,
ir
::
Int32_tAttribute
::
get
(
ctx
,
1
)}};
std
::
vector
<
ir
::
Type
>
output_types
=
{
ir
::
Float32Type
::
get
(
ctx
)};
EXPECT_THROW
(
ir
::
Operation
::
create
(
inputs
,
{},
{},
op_info
),
const
char
*
);
EXPECT_THROW
(
ir
::
Operation
::
create
({},
attrs
,
{},
op_info
),
const
char
*
);
EXPECT_THROW
(
ir
::
Operation
::
create
({},
{},
output_types
,
op_info
),
const
char
*
);
ir
::
Program
program
(
ctx
);
EXPECT_EQ
(
program
.
module_op
().
program
(),
&
program
);
EXPECT_EQ
(
program
.
module_op
().
ir_context
(),
ctx
);
program
.
module_op
()
->
SetAttribute
(
"program"
,
ir
::
PointerAttribute
::
get
(
ctx
,
&
program
));
}
test/cpp/ir/core/ir_program_test.cc
浏览文件 @
49bedfd3
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_op.h"
...
@@ -56,9 +57,7 @@ TEST(program_test, program) {
...
@@ -56,9 +57,7 @@ TEST(program_test, program) {
ctx
->
GetOrRegisterDialect
<
paddle
::
dialect
::
PaddleDialect
>
();
ctx
->
GetOrRegisterDialect
<
paddle
::
dialect
::
PaddleDialect
>
();
// (2) Create an empty program object
// (2) Create an empty program object
ir
::
Program
program
;
ir
::
Program
program
(
ctx
);
// ir::Program *program = new ir::Program();
EXPECT_EQ
(
program
.
block
()
->
size
()
==
0
,
true
);
// (3) Create a float32 DenseTensor Parameter and save into Program
// (3) Create a float32 DenseTensor Parameter and save into Program
ir
::
Type
fp32_dtype
=
ir
::
Float32Type
::
get
(
ctx
);
ir
::
Type
fp32_dtype
=
ir
::
Float32Type
::
get
(
ctx
);
...
@@ -94,7 +93,14 @@ TEST(program_test, program) {
...
@@ -94,7 +93,14 @@ TEST(program_test, program) {
ir
::
Operation
*
op1
=
ir
::
Operation
*
op1
=
ir
::
Operation
::
create
({},
op1_attribute
,
{
dense_tensor_dtype
},
op1_info
);
ir
::
Operation
::
create
({},
op1_attribute
,
{
dense_tensor_dtype
},
op1_info
);
program
.
InsertOp
(
op1
);
ir
::
Block
*
block
=
program
.
block
();
block
->
push_back
(
op1
);
EXPECT_EQ
(
&
program
.
module_op
()
->
GetRegion
(
0
),
block
->
GetParentRegion
());
EXPECT_EQ
(
program
.
module_op
(),
block
->
GetParentOp
());
EXPECT_EQ
(
&
program
,
op1
->
GetParentProgram
());
EXPECT_EQ
(
op1
->
GetResultByIndex
(
0
).
type
().
dialect
().
id
(),
EXPECT_EQ
(
op1
->
GetResultByIndex
(
0
).
type
().
dialect
().
id
(),
paddle_dialect
->
id
());
paddle_dialect
->
id
());
...
@@ -124,7 +130,7 @@ TEST(program_test, program) {
...
@@ -124,7 +130,7 @@ TEST(program_test, program) {
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"b"
)}};
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"b"
)}};
ir
::
Operation
*
op2
=
ir
::
Operation
*
op2
=
ir
::
Operation
::
create
({},
op2_attribute
,
{
dense_tensor_dtype
},
op2_info
);
ir
::
Operation
::
create
({},
op2_attribute
,
{
dense_tensor_dtype
},
op2_info
);
program
.
InsertOp
(
op2
);
block
->
push_back
(
op2
);
EXPECT_EQ
(
op2
->
GetResultByIndex
(
0
).
type
().
dialect
().
id
(),
EXPECT_EQ
(
op2
->
GetResultByIndex
(
0
).
type
().
dialect
().
id
(),
paddle_dialect
->
id
());
paddle_dialect
->
id
());
...
@@ -155,7 +161,7 @@ TEST(program_test, program) {
...
@@ -155,7 +161,7 @@ TEST(program_test, program) {
op3_attribute
,
op3_attribute
,
{
dense_tensor_dtype
},
{
dense_tensor_dtype
},
op3_info
);
op3_info
);
program
.
InsertOp
(
op3
);
block
->
push_back
(
op3
);
phi
::
CPUContext
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
phi
::
CPUContext
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
...
@@ -196,9 +202,12 @@ TEST(program_test, program) {
...
@@ -196,9 +202,12 @@ TEST(program_test, program) {
ir
::
OpInfo
op4_info
=
ctx
->
GetRegisteredOpInfo
(
op4_name
);
ir
::
OpInfo
op4_info
=
ctx
->
GetRegisteredOpInfo
(
op4_name
);
std
::
unordered_map
<
std
::
string
,
ir
::
Attribute
>
op4_attribute
{
std
::
unordered_map
<
std
::
string
,
ir
::
Attribute
>
op4_attribute
{
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"c"
)}};
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"c"
)}};
ir
::
Operation
*
op4
=
ir
::
Operation
::
create
(
{
op3
->
GetResultByIndex
(
0
)},
op4_attribute
,
{},
op4_info
);
ir
::
OperationArgument
op4_argument
(
program
.
InsertOp
(
op4
);
{
op3
->
GetResultByIndex
(
0
)},
{},
{},
op4_info
);
op4_argument
.
addAttributes
(
op4_attribute
.
begin
(),
op4_attribute
.
end
());
ir
::
Operation
*
op4
=
ir
::
Operation
::
create
(
std
::
move
(
op4_argument
));
block
->
push_back
(
op4
);
EXPECT_EQ
(
op4
->
GetOperandByIndex
(
0
).
impl
()
->
source
().
type
().
dialect
().
id
(),
EXPECT_EQ
(
op4
->
GetOperandByIndex
(
0
).
impl
()
->
source
().
type
().
dialect
().
id
(),
paddle_dialect
->
id
());
paddle_dialect
->
id
());
...
@@ -244,7 +253,7 @@ TEST(program_test, slice_combine_test) {
...
@@ -244,7 +253,7 @@ TEST(program_test, slice_combine_test) {
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"a"
)}};
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"a"
)}};
ir
::
Operation
*
op1
=
ir
::
Operation
*
op1
=
ir
::
Operation
::
create
({},
op1_attribute
,
{
fp32_dtype
},
op1_info
);
ir
::
Operation
::
create
({},
op1_attribute
,
{
fp32_dtype
},
op1_info
);
program
.
InsertOp
(
op1
);
program
.
block
()
->
push_back
(
op1
);
// (5) Def b = GetParameterOp("b")
// (5) Def b = GetParameterOp("b")
std
::
string
op2_name
=
std
::
string
(
ir
::
GetParameterOp
::
name
());
std
::
string
op2_name
=
std
::
string
(
ir
::
GetParameterOp
::
name
());
...
@@ -253,7 +262,7 @@ TEST(program_test, slice_combine_test) {
...
@@ -253,7 +262,7 @@ TEST(program_test, slice_combine_test) {
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"b"
)}};
{
"parameter_name"
,
ir
::
StrAttribute
::
get
(
ctx
,
"b"
)}};
ir
::
Operation
*
op2
=
ir
::
Operation
*
op2
=
ir
::
Operation
::
create
({},
op2_attribute
,
{
fp32_dtype
},
op2_info
);
ir
::
Operation
::
create
({},
op2_attribute
,
{
fp32_dtype
},
op2_info
);
program
.
InsertOp
(
op2
);
program
.
block
()
->
push_back
(
op2
);
// (6) Def combine_op = CombineOp("a", "b")
// (6) Def combine_op = CombineOp("a", "b")
std
::
string
combine_op_name
=
std
::
string
(
ir
::
CombineOp
::
name
());
std
::
string
combine_op_name
=
std
::
string
(
ir
::
CombineOp
::
name
());
...
@@ -265,7 +274,7 @@ TEST(program_test, slice_combine_test) {
...
@@ -265,7 +274,7 @@ TEST(program_test, slice_combine_test) {
{},
{},
{
output_type
},
{
output_type
},
combine_op_info
);
combine_op_info
);
program
.
InsertOp
(
combine_op
);
program
.
block
()
->
push_back
(
combine_op
);
// (7) Def slice_op = SliceOp(combine_op, 0)
// (7) Def slice_op = SliceOp(combine_op, 0)
std
::
string
slice_op_name
=
std
::
string
(
ir
::
SliceOp
::
name
());
std
::
string
slice_op_name
=
std
::
string
(
ir
::
SliceOp
::
name
());
...
@@ -276,7 +285,7 @@ TEST(program_test, slice_combine_test) {
...
@@ -276,7 +285,7 @@ TEST(program_test, slice_combine_test) {
{{
"index"
,
index_attr
}},
{{
"index"
,
index_attr
}},
{
fp32_dtype
},
{
fp32_dtype
},
slice_op_info
);
slice_op_info
);
program
.
InsertOp
(
slice_op
);
program
.
block
()
->
push_back
(
slice_op
);
// (8) Traverse Program
// (8) Traverse Program
EXPECT_EQ
(
program
.
block
()
->
size
()
==
4
,
true
);
EXPECT_EQ
(
program
.
block
()
->
size
()
==
4
,
true
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录