Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a25f013b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a25f013b
编写于
8月 07, 2023
作者:
Z
zhangbo9674
提交者:
GitHub
8月 07, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[IR] Sovle bugs (#55991)
* sovle conflict bug * fix bug
上级
ddfbf135
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
174 addition
and
76 deletion
+174
-76
paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
...mework/new_executor/instruction/phi_kernel_instruction.cc
+1
-1
paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
...id/framework/new_executor/interpreter/interpreter_util.cc
+13
-4
paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
...uid/framework/new_executor/interpreter/interpreter_util.h
+2
-0
paddle/fluid/framework/new_executor/new_ir_interpreter.cc
paddle/fluid/framework/new_executor/new_ir_interpreter.cc
+45
-43
paddle/fluid/framework/new_executor/new_ir_interpreter.h
paddle/fluid/framework/new_executor/new_ir_interpreter.h
+3
-1
paddle/fluid/framework/new_executor/program_interpreter.cc
paddle/fluid/framework/new_executor/program_interpreter.cc
+12
-0
paddle/fluid/ir/interface/op_yaml_info_parser.cc
paddle/fluid/ir/interface/op_yaml_info_parser.cc
+22
-0
paddle/fluid/ir/interface/op_yaml_info_parser.h
paddle/fluid/ir/interface/op_yaml_info_parser.h
+4
-0
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
+1
-2
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
+11
-14
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
+13
-8
paddle/fluid/ir_adaptor/translator/op_translator.cc
paddle/fluid/ir_adaptor/translator/op_translator.cc
+1
-3
paddle/fluid/ir_adaptor/translator/program_translator.cc
paddle/fluid/ir_adaptor/translator/program_translator.cc
+44
-0
paddle/fluid/ir_adaptor/translator/program_translator.h
paddle/fluid/ir_adaptor/translator/program_translator.h
+1
-0
paddle/ir/core/attribute.h
paddle/ir/core/attribute.h
+1
-0
未找到文件。
paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc
浏览文件 @
a25f013b
...
...
@@ -333,7 +333,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
std
::
unordered_map
<
ir
::
Value
,
std
::
vector
<
int
>>
outputs
;
for
(
size_t
i
=
0
;
i
<
op
->
num_results
();
i
++
)
{
ir
::
Value
value
=
op
->
result
(
i
);
if
(
value
)
{
if
(
value
&&
value
.
type
()
)
{
PADDLE_ENFORCE_NE
(
value_2_var_name
.
find
(
value
),
value_2_var_name
.
end
(),
...
...
paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
浏览文件 @
a25f013b
...
...
@@ -977,10 +977,7 @@ void BuildOpFuncList(
attr_map
.
at
(
"op_name"
).
dyn_cast
<::
ir
::
StrAttribute
>
().
AsString
();
op_func_node
.
phi_op_name_
=
op_name
;
if
(
op_name
==
"builtin.combine"
||
op_name
==
"pd.feed"
||
op_name
==
"builtin.set_parameter"
||
op_name
==
"builtin.get_parameter"
||
op_name
==
"builtin.slice"
||
op_name
==
"pd.data"
||
op_name
==
"pd.shadow_output"
)
{
if
(
GetSpecialOpNames
().
count
(
op_name
))
{
VLOG
(
6
)
<<
"skip process "
<<
op_name
;
continue
;
}
...
...
@@ -1171,6 +1168,18 @@ void SetDeviceCommContext(::ir::Operation* op,
}
}
std
::
unordered_set
<
std
::
string
>
GetSpecialOpNames
()
{
return
{
"builtin.combine"
,
"builtin.slice"
,
"pd.feed"
,
"builtin.set_parameter"
,
"builtin.get_parameter"
,
"pd.data"
,
"pd.shadow_output"
,
};
}
}
// namespace interpreter
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/interpreter/interpreter_util.h
浏览文件 @
a25f013b
...
...
@@ -124,6 +124,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
void
SetDeviceCommContext
(
::
ir
::
Operation
*
op
,
platform
::
DeviceContext
*
dev_ctx
);
std
::
unordered_set
<
std
::
string
>
GetSpecialOpNames
();
}
// namespace interpreter
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/new_ir_interpreter.cc
浏览文件 @
a25f013b
...
...
@@ -219,10 +219,11 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
&
value_2_var_name_
,
&
variable_2_var_name_
,
&
var_name_2_id_
,
&
variable_list_
,
&
parameter_values_
);
&
variable_list_
);
VLOG
(
4
)
<<
DebugValueInfo
();
SolvePersisableVarNames
();
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
op_func_nodes
;
interpreter
::
BuildOpFuncList
(
place_
,
ir_program_
->
block
(),
...
...
@@ -1595,7 +1596,9 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace(
std
::
stringstream
ss
;
ss
<<
"trace order: "
;
for
(
size_t
idx
=
0
;
idx
<
trace_execute_order_
.
size
();
idx
++
)
{
ss
<<
trace_execute_order_
[
idx
]
<<
" -> "
;
ss
<<
vec_instruction_base_
[
trace_execute_order_
[
idx
]]
->
Name
()
<<
"["
<<
trace_execute_order_
[
idx
]
<<
"]"
<<
" -> "
;
}
ss
<<
"end
\n
"
;
VLOG
(
6
)
<<
ss
.
str
();
...
...
@@ -1616,10 +1619,7 @@ void NewIRInterpreter::BuildInstruction() {
.
at
(
"op_name"
)
.
dyn_cast
<::
ir
::
StrAttribute
>
()
.
AsString
();
if
(
op_name
==
"builtin.combine"
||
op_name
==
"pd.feed"
||
op_name
==
"builtin.set_parameter"
||
op_name
==
"builtin.get_parameter"
||
op_name
==
"builtin.slice"
||
op_name
==
"pd.data"
||
op_name
==
"pd.shadow_output"
)
{
if
(
interpreter
::
GetSpecialOpNames
().
count
(
op_name
))
{
VLOG
(
6
)
<<
"skip process "
<<
op_name
;
continue
;
}
...
...
@@ -1793,18 +1793,8 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) {
VLOG
(
4
)
<<
"GC sync "
<<
GetNameById
(
var_id
);
// persistable var will be ignore while GC
::
ir
::
Value
value
=
GetValueByName
(
GetNameById
(
var_id
));
bool
is_parameter
=
false
;
if
(
value
)
{
for
(
auto
item
:
parameter_values_
)
{
if
(
item
==
value
)
{
is_parameter
=
true
;
break
;
}
}
}
if
(
is_parameter
)
{
VLOG
(
4
)
<<
"value "
<<
value
.
impl
()
<<
" is a parameter, skip gc"
;
if
(
parameter_var_names_
.
count
(
GetNameById
(
var_id
)))
{
VLOG
(
4
)
<<
GetNameById
(
var_id
)
<<
" is a parameter, skip gc"
;
continue
;
}
...
...
@@ -1851,18 +1841,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) {
<<
", ref:"
<<
refs_
[
var_id
]
->
DynamicRef
();
bool
is_ready
=
refs_
[
var_id
]
->
CheckAndDecrease
();
// ignore all persistable var while GCphi
::
ir
::
Value
value
=
GetValueByName
(
GetNameById
(
var_id
));
bool
is_parameter
=
false
;
if
(
value
)
{
for
(
auto
item
:
parameter_values_
)
{
if
(
item
==
value
)
{
is_parameter
=
true
;
break
;
}
}
}
if
(
is_parameter
)
{
VLOG
(
4
)
<<
"value "
<<
value
.
impl
()
<<
" is a parameter, skip gc"
;
if
(
parameter_var_names_
.
count
(
GetNameById
(
var_id
)))
{
VLOG
(
4
)
<<
GetNameById
(
var_id
)
<<
" is a parameter, skip gc"
;
continue
;
}
...
...
@@ -2020,11 +2000,17 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
&
value_2_var_name_
,
&
variable_2_var_name_
,
&
var_name_2_id_
,
&
variable_list_
,
&
parameter_values_
);
&
variable_list_
);
VLOG
(
4
)
<<
"Done BuildScope"
;
VLOG
(
4
)
<<
DebugValueInfo
();
SolvePersisableVarNames
();
VLOG
(
4
)
<<
"Parameter value include: "
;
for
(
auto
parameter
:
parameter_var_names_
)
{
VLOG
(
4
)
<<
"Parameter value: "
<<
parameter
;
}
BuildInstruction
();
VLOG
(
4
)
<<
"Done BuildInstruction"
;
...
...
@@ -2032,15 +2018,9 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
VLOG
(
4
)
<<
"Done PreAnalysis"
;
// Run
if
(
FLAGS_enable_new_ir_in_executor_loop_run
)
{
LOG_FIRST_N
(
INFO
,
1
)
<<
"New ir interpreter is running in BetaRun mode "
"with for_loop version
."
;
"with for_loop version(First step)
."
;
LoopRunImpl
();
}
else
{
LOG_FIRST_N
(
INFO
,
1
)
<<
"New ir interpreter is running in BetaRun mode "
"with trace version."
;
TraceRunImpl
();
}
is_build_
=
true
;
}
else
{
if
(
FLAGS_enable_new_ir_in_executor_loop_run
)
{
...
...
@@ -2177,7 +2157,8 @@ void NewIRInterpreter::TraceRunInstructionList(
auto
instr_id
=
trace_execute_order_
[
idx
];
InstructionBase
*
instr_node
=
vec_instruction_base_
.
at
(
instr_id
).
get
();
VLOG
(
6
)
<<
"Run InstructionBase "
<<
instr_id
;
VLOG
(
6
)
<<
"Run InstructionBase "
<<
instr_node
->
Name
()
<<
"["
<<
instr_id
<<
"]"
;
RunInstructionBase
(
instr_node
);
if
(
UNLIKELY
(
exception_holder_
.
IsCaught
()))
{
...
...
@@ -2263,5 +2244,26 @@ void NewIRInterpreter::PreAnalysis() {
return
nullptr
;
}
void
NewIRInterpreter
::
SolvePersisableVarNames
()
{
VLOG
(
6
)
<<
"SolvePersisableVarNames"
;
for
(
auto
kv
:
value_2_var_name_
)
{
::
ir
::
Value
value
=
kv
.
first
;
std
::
string
var_name
=
kv
.
second
;
::
ir
::
OpResult
result
=
value
.
dyn_cast
<::
ir
::
OpResult
>
();
auto
*
defining_op
=
value
.
GetDefiningOp
();
if
(
defining_op
->
HasAttribute
(
kAttrIsPersisable
))
{
auto
is_persisables
=
defining_op
->
attribute
(
kAttrIsPersisable
)
.
dyn_cast
<::
ir
::
ArrayAttribute
>
()
.
AsVector
();
if
(
is_persisables
[
result
.
GetResultIndex
()]
.
dyn_cast
<::
ir
::
BoolAttribute
>
()
.
data
())
{
VLOG
(
6
)
<<
"parameter_var_names_ include: "
<<
var_name
;
parameter_var_names_
.
insert
(
var_name
);
}
}
}
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/new_executor/new_ir_interpreter.h
浏览文件 @
a25f013b
...
...
@@ -235,6 +235,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
void
RecordStreamForGC
(
InstructionBase
*
instr
);
void
SolvePersisableVarNames
();
InstructionSchedulingPriorityLess
ir_instruction_scheduling_priority_less
;
std
::
unique_ptr
<::
ir
::
Program
>
ir_program_
{
nullptr
};
...
...
@@ -260,7 +262,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
// Note(zhangbo): set_parameter_op's input and get_parameter_op's output
// belongs to a parameter and cannot GC.
std
::
vector
<::
ir
::
Value
>
parameter_valu
es_
;
std
::
unordered_set
<
std
::
string
>
parameter_var_nam
es_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/new_executor/program_interpreter.cc
浏览文件 @
a25f013b
...
...
@@ -1503,6 +1503,18 @@ void ProgramInterpreter::AnalyseExecuteOrderForTrace() {
"trace_order size should be equal to dependecy_count_."
));
trace_execute_order_
=
trace_order
;
std
::
stringstream
ss
;
ss
<<
"trace order: "
;
for
(
size_t
idx
=
0
;
idx
<
trace_execute_order_
.
size
();
idx
++
)
{
ss
<<
vec_instruction_
[
trace_execute_order_
[
idx
]]
.
OpFunc
()
->
operator_base_
->
Type
()
<<
"["
<<
trace_execute_order_
[
idx
]
<<
"]"
<<
" -> "
;
}
ss
<<
"end
\n
"
;
VLOG
(
6
)
<<
ss
.
str
();
}
}
// namespace framework
...
...
paddle/fluid/ir/interface/op_yaml_info_parser.cc
浏览文件 @
a25f013b
...
...
@@ -118,6 +118,28 @@ const std::string& OpYamlInfoParser::InplaceName(
"Can not find inplace input of [%s]."
,
out_name
));
}
bool
OpYamlInfoParser
::
HasView
(
const
std
::
string
&
out_name
)
const
{
auto
&
view_info
=
std
::
get
<
3
>
(
op_info_tuple_
).
view
;
for
(
size_t
i
=
0
;
i
<
view_info
.
size
();
i
++
)
{
if
(
out_name
==
view_info
[
i
].
first
)
{
return
true
;
}
}
return
false
;
}
const
std
::
string
&
OpYamlInfoParser
::
ViewName
(
const
std
::
string
&
out_name
)
const
{
auto
&
view_info
=
std
::
get
<
3
>
(
op_info_tuple_
).
view
;
for
(
size_t
i
=
0
;
i
<
view_info
.
size
();
i
++
)
{
if
(
out_name
==
view_info
[
i
].
first
)
{
return
view_info
[
i
].
second
;
}
}
PADDLE_THROW
(
phi
::
errors
::
PreconditionNotMet
(
"Can not find inplace input of [%s]."
,
out_name
));
}
void
OpYamlInfoParser
::
parse
()
{
auto
input_info
=
std
::
get
<
0
>
(
op_info_tuple_
);
...
...
paddle/fluid/ir/interface/op_yaml_info_parser.h
浏览文件 @
a25f013b
...
...
@@ -53,6 +53,10 @@ class OpYamlInfoParser {
const
std
::
string
&
InplaceName
(
const
std
::
string
&
out_name
)
const
;
bool
HasView
(
const
std
::
string
&
out_name
)
const
;
const
std
::
string
&
ViewName
(
const
std
::
string
&
out_name
)
const
;
private:
void
parse
();
inline
const
std
::
vector
<
OpInputInfo
>&
InputInfo
()
const
{
...
...
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
浏览文件 @
a25f013b
...
...
@@ -69,8 +69,7 @@ class PhiKernelAdaptor {
&
value_2_var_name
,
&
variable_2_var_name
,
&
var_name_2_id
,
&
variable_list
,
nullptr
);
&
variable_list
);
ir
::
IrContext
*
ctx
=
ir
::
IrContext
::
Instance
();
ctx
->
GetOrRegisterDialect
<
paddle
::
dialect
::
PaddleDialect
>
();
...
...
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc
浏览文件 @
a25f013b
...
...
@@ -217,8 +217,7 @@ void HandleForSpecialOp(
std
::
unordered_map
<
const
paddle
::
framework
::
Variable
*
,
std
::
string
>*
variable_2_var_name
,
std
::
map
<
std
::
string
,
int
>*
var_name_2_id
,
std
::
vector
<
paddle
::
framework
::
Variable
*>*
variable_list
,
std
::
vector
<::
ir
::
Value
>*
parameter_values
)
{
std
::
vector
<
paddle
::
framework
::
Variable
*>*
variable_list
)
{
std
::
string
op_name
=
op
->
name
();
if
(
op
->
attributes
().
count
(
"op_name"
))
{
op_name
=
...
...
@@ -347,10 +346,6 @@ void HandleForSpecialOp(
value_2_var_name
,
variable_2_var_name
,
var_name_2_id
);
if
(
parameter_values
)
{
parameter_values
->
push_back
(
value
);
}
}
if
(
op_name
==
"pd.shadow_output"
)
{
...
...
@@ -390,10 +385,6 @@ void HandleForSpecialOp(
variable_2_var_name
,
var_name_2_id
,
variable_list
);
if
(
parameter_values
)
{
parameter_values
->
push_back
(
value
);
}
}
if
(
op_name
==
"builtin.slice"
)
{
...
...
@@ -458,6 +449,14 @@ void HandleForInplaceOp(
VLOG
(
4
)
<<
"inplace: "
<<
value_name
<<
" -> "
<<
inplace_name
<<
" (var: "
<<
var_name
<<
")"
;
value_2_var_name
->
emplace
(
value
,
var_name
);
}
else
if
(
yaml_parser
.
HasView
(
value_name
))
{
std
::
string
view_name
=
yaml_parser
.
ViewName
(
value_name
);
ir
::
Value
view_value
=
op
->
operand_source
(
yaml_parser
.
InputName2Id
().
at
(
view_name
));
std
::
string
var_name
=
value_2_var_name
->
at
(
view_value
);
VLOG
(
4
)
<<
"view: "
<<
value_name
<<
" -> "
<<
view_name
<<
" (var: "
<<
var_name
<<
")"
;
value_2_var_name
->
emplace
(
value
,
var_name
);
}
else
{
BuildValue
(
value
,
inner_scope
,
...
...
@@ -479,8 +478,7 @@ void BuildScope(const ir::Block& block,
std
::
unordered_map
<
const
paddle
::
framework
::
Variable
*
,
std
::
string
>*
variable_2_var_name
,
std
::
map
<
std
::
string
,
int
>*
var_name_2_id
,
std
::
vector
<
paddle
::
framework
::
Variable
*>*
variable_list
,
std
::
vector
<::
ir
::
Value
>*
parameter_values
)
{
std
::
vector
<
paddle
::
framework
::
Variable
*>*
variable_list
)
{
VLOG
(
4
)
<<
"***** [before build] scope"
<<
"("
<<
inner_scope
<<
") ******
\n
"
<<
paddle
::
framework
::
GenScopeTreeDebugInfo
(
...
...
@@ -506,8 +504,7 @@ void BuildScope(const ir::Block& block,
value_2_var_name
,
variable_2_var_name
,
var_name_2_id
,
variable_list
,
parameter_values
);
variable_list
);
continue
;
}
...
...
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
浏览文件 @
a25f013b
...
...
@@ -49,8 +49,7 @@ void BuildScope(const ir::Block& block,
std
::
unordered_map
<
const
paddle
::
framework
::
Variable
*
,
std
::
string
>*
variable_2_var_name
,
std
::
map
<
std
::
string
,
int
>*
var_name_2_id
,
std
::
vector
<
paddle
::
framework
::
Variable
*>*
variable_list
,
std
::
vector
<::
ir
::
Value
>*
parameter_values
);
std
::
vector
<
paddle
::
framework
::
Variable
*>*
variable_list
);
void
BuildRuntimeContext
(
ir
::
Operation
*
op
,
...
...
@@ -288,23 +287,29 @@ void BuildPhiContext(ir::Operation* op,
// TODO(phlrain): use var type instead of op name
for
(
size_t
i
=
0
;
i
<
op
->
num_results
();
++
i
)
{
ir
::
Value
out_ptr
=
op
->
result
(
i
);
auto
out_type
=
out_ptr
.
type
();
if
(
out_type
)
{
auto
name
=
name_map
.
at
(
out_ptr
);
VLOG
(
6
)
<<
"ctx->EmplaceBackOutput: "
<<
name
;
auto
out_type
=
out_ptr
.
type
();
}
else
{
VLOG
(
6
)
<<
"ctx->EmplaceBackOutput : an optioanl output"
;
}
if
(
!
out_type
)
{
phi
::
DenseTensor
*
ptr
=
nullptr
;
OutType
out_ptr
(
ptr
);
ctx
->
EmplaceBackOutput
(
out_ptr
);
}
else
if
(
out_type
.
isa
<
paddle
::
dialect
::
AllocatedDenseTensorType
>
())
{
ctx
->
EmplaceBackOutput
(
OutType
(
const_cast
<
phi
::
DenseTensor
*>
(
&
(
inner_scope
->
FindVar
(
name
)
->
Get
<
phi
::
DenseTensor
>
()))));
&
(
inner_scope
->
FindVar
(
name_map
.
at
(
out_ptr
))
->
Get
<
phi
::
DenseTensor
>
()))));
}
else
if
(
out_type
.
isa
<
paddle
::
dialect
::
AllocatedSelectedRowsType
>
())
{
ctx
->
EmplaceBackOutput
(
OutType
(
const_cast
<
phi
::
SelectedRows
*>
(
&
(
inner_scope
->
FindVar
(
name
)
->
Get
<
phi
::
SelectedRows
>
()))));
&
(
inner_scope
->
FindVar
(
name_map
.
at
(
out_ptr
))
->
Get
<
phi
::
SelectedRows
>
()))));
}
else
if
(
out_type
.
isa
<
ir
::
VectorType
>
())
{
OutListType
outputs
;
auto
&
variable_array
=
scope
->
FindVar
(
name
)
->
Get
<
paddle
::
framework
::
VariableRefArray
>
();
auto
&
variable_array
=
scope
->
FindVar
(
name_map
.
at
(
out_ptr
))
->
Get
<
paddle
::
framework
::
VariableRefArray
>
();
for
(
size_t
i
=
0
;
i
<
variable_array
.
size
();
++
i
)
{
outputs
.
emplace_back
(
OutType
(
const_cast
<
phi
::
DenseTensor
*>
(
&
(
variable_array
[
i
]
->
Get
<
phi
::
DenseTensor
>
()))));
...
...
paddle/fluid/ir_adaptor/translator/op_translator.cc
浏览文件 @
a25f013b
...
...
@@ -72,9 +72,7 @@ using AttributeHandlerFn = std::function<ir::Attribute(
constexpr
char
kTargetDialectPrefix
[]
=
"pd."
;
constexpr
char
kEmptyVarName
[]
=
"@EMPTY@"
;
static
const
std
::
unordered_set
<
std
::
string
>
special_non_inplace_ops
=
{
"batch_norm"
,
};
static
const
std
::
unordered_set
<
std
::
string
>
special_non_inplace_ops
=
{};
static
const
std
::
unordered_set
<
std
::
string
>
special_inplace_ops
=
{
"adagrad"
,
...
...
paddle/fluid/ir_adaptor/translator/program_translator.cc
浏览文件 @
a25f013b
...
...
@@ -77,6 +77,11 @@ void ProgramTranslator::Translate() {
const
BlockDesc
&
block
=
legacy_program_
->
Block
(
block_idx
);
SetStopGradientAttributeForAllValue
(
block
);
}
for
(
size_t
block_idx
=
0
;
block_idx
<
legacy_program_
->
Size
();
block_idx
++
)
{
const
BlockDesc
&
block
=
legacy_program_
->
Block
(
block_idx
);
SetIsPersisableAttributeForAllValue
(
block
);
}
}
inline
ir
::
Operation
*
InsertGetParamaterOp
(
ir
::
IrContext
*
ctx
,
...
...
@@ -268,5 +273,44 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
}
}
void
ProgramTranslator
::
SetIsPersisableAttributeForAllValue
(
const
BlockDesc
&
block
)
{
// Currently we set is persisable for operation that generated a value
// connected with VarDesc
for
(
const
auto
&
[
var_name
,
value_info
]
:
param_map_
)
{
if
(
no_cast_var_names
.
count
(
var_name
)
!=
0
)
continue
;
VLOG
(
10
)
<<
"[op translated][is persisable]"
<<
var_name
;
VarDesc
*
var
=
block
.
FindVarRecursive
(
var_name
);
if
(
var
==
nullptr
)
{
continue
;
}
ir
::
OpResult
value
=
value_info
.
value
;
if
(
!
value
)
{
PADDLE_THROW
(
phi
::
errors
::
PreconditionNotMet
(
"Value of [%s] can not ber None"
,
var_name
));
}
auto
*
defining_op
=
value
.
owner
();
PADDLE_ENFORCE_NOT_NULL
(
defining_op
,
phi
::
errors
::
PreconditionNotMet
(
"Defining operator of [%s] can not be nullptr"
,
var_name
));
VLOG
(
8
)
<<
"[op translated][is persisable]"
<<
var_name
<<
" from: "
<<
defining_op
->
name
();
std
::
vector
<
ir
::
Attribute
>
is_persisable
;
if
(
defining_op
->
HasAttribute
(
kAttrIsPersisable
))
{
is_persisable
=
defining_op
->
attribute
(
kAttrIsPersisable
)
.
dyn_cast
<
ir
::
ArrayAttribute
>
()
.
AsVector
();
}
else
{
is_persisable
=
std
::
vector
<
ir
::
Attribute
>
(
defining_op
->
num_results
(),
ir
::
BoolAttribute
::
get
(
ctx_
,
false
));
}
is_persisable
[
value
.
GetResultIndex
()]
=
ir
::
BoolAttribute
::
get
(
ctx_
,
var
->
Persistable
());
defining_op
->
set_attribute
(
kAttrIsPersisable
,
ir
::
ArrayAttribute
::
get
(
ctx_
,
is_persisable
));
}
}
}
// namespace translator
}
// namespace paddle
paddle/fluid/ir_adaptor/translator/program_translator.h
浏览文件 @
a25f013b
...
...
@@ -79,6 +79,7 @@ class ProgramTranslator {
void
InsertOperationToSingleBlock
(
const
BlockDesc
&
block
);
void
SetParameterFromSingleBlock
(
const
BlockDesc
&
block
);
void
SetStopGradientAttributeForAllValue
(
const
BlockDesc
&
block
);
void
SetIsPersisableAttributeForAllValue
(
const
BlockDesc
&
block
);
};
}
// namespace translator
...
...
paddle/ir/core/attribute.h
浏览文件 @
a25f013b
...
...
@@ -18,6 +18,7 @@
#include "paddle/ir/core/type_id.h"
constexpr
char
kAttrStopGradients
[]
=
"stop_gradient"
;
constexpr
char
kAttrIsPersisable
[]
=
"is_persisable"
;
namespace
ir
{
class
AttributeStorage
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录