PaddlePaddle / Paddle

Commit a25f013b (unverified)
Authored by zhangbo9674 on Aug 07, 2023; committed via GitHub on Aug 07, 2023.
[IR] Sovle bugs (#55991)
* sovle conflict bug
* fix bug
Parent commit: ddfbf135
Showing 15 changed files with 174 additions and 76 deletions (+174 −76)

paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc   +1 −1
paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc         +13 −4
paddle/fluid/framework/new_executor/interpreter/interpreter_util.h          +2 −0
paddle/fluid/framework/new_executor/new_ir_interpreter.cc                   +45 −43
paddle/fluid/framework/new_executor/new_ir_interpreter.h                    +3 −1
paddle/fluid/framework/new_executor/program_interpreter.cc                  +12 −0
paddle/fluid/ir/interface/op_yaml_info_parser.cc                            +22 −0
paddle/fluid/ir/interface/op_yaml_info_parser.h                             +4 −0
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h                     +1 −2
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc                       +11 −14
paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h                        +13 −8
paddle/fluid/ir_adaptor/translator/op_translator.cc                         +1 −3
paddle/fluid/ir_adaptor/translator/program_translator.cc                    +44 −0
paddle/fluid/ir_adaptor/translator/program_translator.h                     +1 −0
paddle/ir/core/attribute.h                                                  +1 −0

paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc

@@ -333,7 +333,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
   std::unordered_map<ir::Value, std::vector<int>> outputs;
   for (size_t i = 0; i < op->num_results(); i++) {
     ir::Value value = op->result(i);
-    if (value) {
+    if (value && value.type()) {
       PADDLE_ENFORCE_NE(
           value_2_var_name.find(value),
           value_2_var_name.end(),
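
This guard handles results that exist but carry no type, such as optional outputs that were never materialized; such a value has no entry in value_2_var_name, so registering it would trip the PADDLE_ENFORCE_NE check above. A minimal standalone sketch of the same guard pattern, with an invented FakeValue type standing in for ir::Value:

    #include <iostream>
    #include <optional>
    #include <string>

    // Hypothetical stand-in for ir::Value: a result can be missing entirely, or
    // present but untyped (an optional output that was never materialized).
    struct FakeValue {
      bool defined = false;
      std::optional<std::string> type;
    };

    bool ShouldRegister(const FakeValue& v) {
      // Mirrors the change from `if (value)` to `if (value && value.type())`:
      // only results that exist *and* carry a type get an inputs/outputs id.
      return v.defined && v.type.has_value();
    }

    int main() {
      std::cout << ShouldRegister({true, "dense_tensor"}) << "\n";  // 1
      std::cout << ShouldRegister({true, std::nullopt}) << "\n";    // 0: untyped result
      std::cout << ShouldRegister({false, std::nullopt}) << "\n";   // 0: no result
    }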

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

@@ -977,10 +977,7 @@ void BuildOpFuncList(
         attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
     op_func_node.phi_op_name_ = op_name;

-    if (op_name == "builtin.combine" || op_name == "pd.feed" ||
-        op_name == "builtin.set_parameter" ||
-        op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
-        op_name == "pd.data" || op_name == "pd.shadow_output") {
+    if (GetSpecialOpNames().count(op_name)) {
       VLOG(6) << "skip process " << op_name;
       continue;
     }

@@ -1171,6 +1168,18 @@ void SetDeviceCommContext(::ir::Operation* op,
   }
 }

+std::unordered_set<std::string> GetSpecialOpNames() {
+  return {
+      "builtin.combine",
+      "builtin.slice",
+      "pd.feed",
+      "builtin.set_parameter",
+      "builtin.get_parameter",
+      "pd.data",
+      "pd.shadow_output",
+  };
+}
+
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/new_executor/interpreter/interpreter_util.h

@@ -124,6 +124,8 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
 void SetDeviceCommContext(::ir::Operation* op,
                           platform::DeviceContext* dev_ctx);

+std::unordered_set<std::string> GetSpecialOpNames();
+
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle
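
The same special-op filter used to be spelled out as a chain of string comparisons in two call sites; GetSpecialOpNames() centralizes it as an unordered_set so BuildOpFuncList and NewIRInterpreter::BuildInstruction share one list. A standalone sketch of the set-membership pattern (the SpecialOpNames name here is illustrative, not Paddle's API):

    #include <iostream>
    #include <string>
    #include <unordered_set>

    // Hypothetical stand-in for interpreter::GetSpecialOpNames(): one shared list
    // of op names that the executor skips instead of building a kernel for.
    const std::unordered_set<std::string>& SpecialOpNames() {
      static const std::unordered_set<std::string> names = {
          "builtin.combine", "builtin.slice", "pd.feed", "builtin.set_parameter",
          "builtin.get_parameter", "pd.data", "pd.shadow_output"};
      return names;
    }

    int main() {
      for (const char* op : {"pd.feed", "pd.matmul"}) {
        // Set lookup replaces the old chain of op_name == "..." comparisons.
        std::cout << op << (SpecialOpNames().count(op) ? ": skip" : ": build") << "\n";
      }
    }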

paddle/fluid/framework/new_executor/new_ir_interpreter.cc

@@ -219,10 +219,11 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
                     &value_2_var_name_,
                     &variable_2_var_name_,
                     &var_name_2_id_,
-                    &variable_list_,
-                    &parameter_values_);
+                    &variable_list_);
     VLOG(4) << DebugValueInfo();

+    SolvePersisableVarNames();
+
     std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     interpreter::BuildOpFuncList(place_,
                                  ir_program_->block(),

@@ -1595,7 +1596,9 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace(
   std::stringstream ss;
   ss << "trace order: ";
   for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
-    ss << trace_execute_order_[idx] << " -> ";
+    ss << vec_instruction_base_[trace_execute_order_[idx]]->Name() << "["
+       << trace_execute_order_[idx] << "]"
+       << " -> ";
   }
   ss << "end\n";
   VLOG(6) << ss.str();

@@ -1616,10 +1619,7 @@ void NewIRInterpreter::BuildInstruction() {
             .at("op_name")
             .dyn_cast<::ir::StrAttribute>()
             .AsString();
-    if (op_name == "builtin.combine" || op_name == "pd.feed" ||
-        op_name == "builtin.set_parameter" ||
-        op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
-        op_name == "pd.data" || op_name == "pd.shadow_output") {
+    if (interpreter::GetSpecialOpNames().count(op_name)) {
       VLOG(6) << "skip process " << op_name;
       continue;
     }

@@ -1793,18 +1793,8 @@ void NewIRInterpreter::RecordStreamForGC(InstructionBase* instr) {
       VLOG(4) << "GC sync " << GetNameById(var_id);

       // persistable var will be ignore while GC
-      ::ir::Value value = GetValueByName(GetNameById(var_id));
-      bool is_parameter = false;
-      if (value) {
-        for (auto item : parameter_values_) {
-          if (item == value) {
-            is_parameter = true;
-            break;
-          }
-        }
-      }
-      if (is_parameter) {
-        VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
+      if (parameter_var_names_.count(GetNameById(var_id))) {
+        VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc";
         continue;
       }

@@ -1851,18 +1841,8 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) {
             << ", ref:" << refs_[var_id]->DynamicRef();
     bool is_ready = refs_[var_id]->CheckAndDecrease();
     // ignore all persistable var while GCphi
-    ::ir::Value value = GetValueByName(GetNameById(var_id));
-    bool is_parameter = false;
-    if (value) {
-      for (auto item : parameter_values_) {
-        if (item == value) {
-          is_parameter = true;
-          break;
-        }
-      }
-    }
-    if (is_parameter) {
-      VLOG(4) << "value " << value.impl() << " is a parameter, skip gc";
+    if (parameter_var_names_.count(GetNameById(var_id))) {
+      VLOG(4) << GetNameById(var_id) << " is a parameter, skip gc";
       continue;
     }

@@ -2020,11 +2000,17 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
                     &value_2_var_name_,
                     &variable_2_var_name_,
                     &var_name_2_id_,
-                    &variable_list_,
-                    &parameter_values_);
+                    &variable_list_);
     VLOG(4) << "Done BuildScope";
     VLOG(4) << DebugValueInfo();

+    SolvePersisableVarNames();
+    VLOG(4) << "Parameter value include: ";
+    for (auto parameter : parameter_var_names_) {
+      VLOG(4) << "Parameter value: " << parameter;
+    }
+
     BuildInstruction();
     VLOG(4) << "Done BuildInstruction";

@@ -2032,15 +2018,9 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
     VLOG(4) << "Done PreAnalysis";
     // Run
-    if (FLAGS_enable_new_ir_in_executor_loop_run) {
-      LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
-                              "with for_loop version(First step).";
-      LoopRunImpl();
-    } else {
-      LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
-                              "with trace version.";
-      TraceRunImpl();
-    }
+    LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
+                            "with for_loop version.";
+    LoopRunImpl();
     is_build_ = true;
   } else {
     if (FLAGS_enable_new_ir_in_executor_loop_run) {

@@ -2177,7 +2157,8 @@ void NewIRInterpreter::TraceRunInstructionList(
     auto instr_id = trace_execute_order_[idx];
     InstructionBase* instr_node = vec_instruction_base_.at(instr_id).get();
-    VLOG(6) << "Run InstructionBase " << instr_id;
+    VLOG(6) << "Run InstructionBase " << instr_node->Name() << "[" << instr_id
+            << "]";
     RunInstructionBase(instr_node);

     if (UNLIKELY(exception_holder_.IsCaught())) {

@@ -2263,5 +2244,26 @@ void NewIRInterpreter::PreAnalysis() {
   return nullptr;
 }

+void NewIRInterpreter::SolvePersisableVarNames() {
+  VLOG(6) << "SolvePersisableVarNames";
+  for (auto kv : value_2_var_name_) {
+    ::ir::Value value = kv.first;
+    std::string var_name = kv.second;
+    ::ir::OpResult result = value.dyn_cast<::ir::OpResult>();
+    auto* defining_op = value.GetDefiningOp();
+    if (defining_op->HasAttribute(kAttrIsPersisable)) {
+      auto is_persisables = defining_op->attribute(kAttrIsPersisable)
+                                .dyn_cast<::ir::ArrayAttribute>()
+                                .AsVector();
+      if (is_persisables[result.GetResultIndex()]
+              .dyn_cast<::ir::BoolAttribute>()
+              .data()) {
+        VLOG(6) << "parameter_var_names_ include: " << var_name;
+        parameter_var_names_.insert(var_name);
+      }
+    }
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/new_executor/new_ir_interpreter.h

@@ -235,6 +235,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   void RecordStreamForGC(InstructionBase* instr);

+  void SolvePersisableVarNames();
+
   InstructionSchedulingPriorityLess ir_instruction_scheduling_priority_less;

   std::unique_ptr<::ir::Program> ir_program_{nullptr};

@@ -260,7 +262,7 @@ class NewIRInterpreter : public InterpreterBaseImpl {
   // Note(zhangbo): set_parameter_op's input and get_parameter_op's output
   // belongs to a parameter and cannot GC.
-  std::vector<::ir::Value> parameter_values_;
+  std::unordered_set<std::string> parameter_var_names_;
 };

 }  // namespace framework
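
The GC paths used to resolve a value by name and then linearly scan parameter_values_ for every candidate variable; SolvePersisableVarNames() now walks value_2_var_name_ once up front and records the persistable variable names, so RecordStreamForGC and CheckGC reduce to a set lookup. A standalone sketch of that trade-off (container contents and names are invented for illustration):

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>

    int main() {
      // Stand-in for value_2_var_name_: value id -> variable name.
      std::unordered_map<int, std::string> value_to_name = {
          {0, "fc_0.w_0"}, {1, "tmp_3"}, {2, "fc_0.b_0"}};
      // Pretend results 0 and 2 carry is_persisable == true.
      std::unordered_set<int> persisable_values = {0, 2};

      // One pass up front (the SolvePersisableVarNames idea): collect the names.
      std::unordered_set<std::string> parameter_var_names;
      for (const auto& kv : value_to_name) {
        if (persisable_values.count(kv.first)) parameter_var_names.insert(kv.second);
      }

      // The per-variable GC check is now a hash lookup instead of a linear scan.
      for (const auto& kv : value_to_name) {
        bool skip = parameter_var_names.count(kv.second) > 0;
        std::cout << kv.second << (skip ? ": skip gc\n" : ": collect\n");
      }
    }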

paddle/fluid/framework/new_executor/program_interpreter.cc

@@ -1503,6 +1503,18 @@ void ProgramInterpreter::AnalyseExecuteOrderForTrace() {
           "trace_order size should be equal to dependecy_count_."));

   trace_execute_order_ = trace_order;
+
+  std::stringstream ss;
+  ss << "trace order: ";
+  for (size_t idx = 0; idx < trace_execute_order_.size(); idx++) {
+    ss << vec_instruction_[trace_execute_order_[idx]]
+              .OpFunc()
+              ->operator_base_->Type()
+       << "[" << trace_execute_order_[idx] << "]"
+       << " -> ";
+  }
+  ss << "end\n";
+  VLOG(6) << ss.str();
 }

 }  // namespace framework
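
Both interpreters now log the traced schedule as op_name[index] pairs rather than bare indices, which makes the VLOG output readable without cross-referencing the instruction list. A standalone sketch of the string being assembled (the instruction names and order here are made up):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> instruction_names = {"pd.feed", "pd.matmul", "pd.fetch"};
      std::vector<size_t> trace_execute_order = {0, 1, 2};

      std::stringstream ss;
      ss << "trace order: ";
      for (size_t idx = 0; idx < trace_execute_order.size(); idx++) {
        // Mirrors the new logging: "<name>[<instruction id>] -> ".
        ss << instruction_names[trace_execute_order[idx]] << "["
           << trace_execute_order[idx] << "]" << " -> ";
      }
      ss << "end\n";
      std::cout << ss.str();  // trace order: pd.feed[0] -> pd.matmul[1] -> pd.fetch[2] -> end
    }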

paddle/fluid/ir/interface/op_yaml_info_parser.cc

@@ -118,6 +118,28 @@ const std::string& OpYamlInfoParser::InplaceName(
       "Can not find inplace input of [%s].", out_name));
 }

+bool OpYamlInfoParser::HasView(const std::string& out_name) const {
+  auto& view_info = std::get<3>(op_info_tuple_).view;
+  for (size_t i = 0; i < view_info.size(); i++) {
+    if (out_name == view_info[i].first) {
+      return true;
+    }
+  }
+  return false;
+}
+
+const std::string& OpYamlInfoParser::ViewName(
+    const std::string& out_name) const {
+  auto& view_info = std::get<3>(op_info_tuple_).view;
+  for (size_t i = 0; i < view_info.size(); i++) {
+    if (out_name == view_info[i].first) {
+      return view_info[i].second;
+    }
+  }
+  PADDLE_THROW(phi::errors::PreconditionNotMet(
+      "Can not find inplace input of [%s].", out_name));
+}
+
 void OpYamlInfoParser::parse() {
   auto input_info = std::get<0>(op_info_tuple_);

paddle/fluid/ir/interface/op_yaml_info_parser.h

@@ -53,6 +53,10 @@ class OpYamlInfoParser {
   const std::string& InplaceName(const std::string& out_name) const;

+  bool HasView(const std::string& out_name) const;
+
+  const std::string& ViewName(const std::string& out_name) const;
+
  private:
   void parse();

   inline const std::vector<OpInputInfo>& InputInfo() const {
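
HasView and ViewName mirror the existing InplaceName lookup: they scan the view list stored in the op info tuple (pairs of output name and source input name) so the kernel adaptor can alias a view output onto its source input. A standalone sketch of that pair-list lookup (the view table below is invented for illustration):

    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <utility>
    #include <vector>

    using ViewInfo = std::vector<std::pair<std::string, std::string>>;  // {output, input}

    bool HasView(const ViewInfo& view_info, const std::string& out_name) {
      for (size_t i = 0; i < view_info.size(); i++) {
        if (out_name == view_info[i].first) return true;
      }
      return false;
    }

    const std::string& ViewName(const ViewInfo& view_info, const std::string& out_name) {
      for (size_t i = 0; i < view_info.size(); i++) {
        if (out_name == view_info[i].first) return view_info[i].second;
      }
      throw std::out_of_range("no view entry for " + out_name);
    }

    int main() {
      ViewInfo views = {{"out", "x"}};  // e.g. a reshape-like op viewing input x
      std::cout << HasView(views, "out") << " " << ViewName(views, "out") << "\n";
    }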

paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h

@@ -69,8 +69,7 @@ class PhiKernelAdaptor {
                &value_2_var_name,
                &variable_2_var_name,
                &var_name_2_id,
-               &variable_list,
-               nullptr);
+               &variable_list);

     ir::IrContext* ctx = ir::IrContext::Instance();
     ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();

paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc

@@ -217,8 +217,7 @@ void HandleForSpecialOp(
     std::unordered_map<const paddle::framework::Variable*, std::string>*
         variable_2_var_name,
     std::map<std::string, int>* var_name_2_id,
-    std::vector<paddle::framework::Variable*>* variable_list,
-    std::vector<::ir::Value>* parameter_values) {
+    std::vector<paddle::framework::Variable*>* variable_list) {
   std::string op_name = op->name();
   if (op->attributes().count("op_name")) {
     op_name =

@@ -347,10 +346,6 @@ void HandleForSpecialOp(
                   value_2_var_name,
                   variable_2_var_name,
                   var_name_2_id);
-
-    if (parameter_values) {
-      parameter_values->push_back(value);
-    }
   }

   if (op_name == "pd.shadow_output") {

@@ -390,10 +385,6 @@ void HandleForSpecialOp(
               variable_2_var_name,
               var_name_2_id,
               variable_list);
-
-    if (parameter_values) {
-      parameter_values->push_back(value);
-    }
   }

   if (op_name == "builtin.slice") {

@@ -458,6 +449,14 @@ void HandleForInplaceOp(
       VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
               << " (var: " << var_name << ")";
       value_2_var_name->emplace(value, var_name);
+    } else if (yaml_parser.HasView(value_name)) {
+      std::string view_name = yaml_parser.ViewName(value_name);
+      ir::Value view_value =
+          op->operand_source(yaml_parser.InputName2Id().at(view_name));
+      std::string var_name = value_2_var_name->at(view_value);
+      VLOG(4) << "view: " << value_name << " -> " << view_name
+              << " (var: " << var_name << ")";
+      value_2_var_name->emplace(value, var_name);
     } else {
       BuildValue(value,
                  inner_scope,

@@ -479,8 +478,7 @@ void BuildScope(const ir::Block& block,
                 std::unordered_map<const paddle::framework::Variable*,
                                    std::string>* variable_2_var_name,
                 std::map<std::string, int>* var_name_2_id,
-                std::vector<paddle::framework::Variable*>* variable_list,
-                std::vector<::ir::Value>* parameter_values) {
+                std::vector<paddle::framework::Variable*>* variable_list) {
   VLOG(4) << "***** [before build] scope"
           << "(" << inner_scope << ") ******\n"
           << paddle::framework::GenScopeTreeDebugInfo(

@@ -506,8 +504,7 @@ void BuildScope(const ir::Block& block,
                          value_2_var_name,
                          variable_2_var_name,
                          var_name_2_id,
-                         variable_list,
-                         parameter_values);
+                         variable_list);
       continue;
     }
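
With the new branch, HandleForInplaceOp treats a view output like an inplace output: instead of allocating a fresh variable it reuses the variable already registered for the view's source operand, so both values map to the same storage in value_2_var_name. A tiny standalone sketch of that aliasing (value ids and variable names are hypothetical):

    #include <iostream>
    #include <string>
    #include <unordered_map>

    int main() {
      // Stand-in for value_2_var_name: value id -> variable name.
      std::unordered_map<int, std::string> value_2_var_name = {{/*input x*/ 0, "var_x"}};

      // A "view" output (value id 1) whose view source is value 0:
      // reuse the source's variable instead of creating a new "var_1".
      int view_source = 0, view_output = 1;
      value_2_var_name.emplace(view_output, value_2_var_name.at(view_source));

      std::cout << "view output aliases " << value_2_var_name.at(view_output) << "\n";
    }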

paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h

@@ -49,8 +49,7 @@ void BuildScope(const ir::Block& block,
                 std::unordered_map<const paddle::framework::Variable*,
                                    std::string>* variable_2_var_name,
                 std::map<std::string, int>* var_name_2_id,
-                std::vector<paddle::framework::Variable*>* variable_list,
-                std::vector<::ir::Value>* parameter_values);
+                std::vector<paddle::framework::Variable*>* variable_list);

 void BuildRuntimeContext(
     ir::Operation* op,

@@ -288,23 +287,29 @@ void BuildPhiContext(ir::Operation* op,
   // TODO(phlrain): use var type instead of op name
   for (size_t i = 0; i < op->num_results(); ++i) {
     ir::Value out_ptr = op->result(i);
-    auto name = name_map.at(out_ptr);
-    VLOG(6) << "ctx->EmplaceBackOutput: " << name;
     auto out_type = out_ptr.type();
+    if (out_type) {
+      auto name = name_map.at(out_ptr);
+      VLOG(6) << "ctx->EmplaceBackOutput: " << name;
+    } else {
+      VLOG(6) << "ctx->EmplaceBackOutput : an optioanl output";
+    }
     if (!out_type) {
       phi::DenseTensor* ptr = nullptr;
       OutType out_ptr(ptr);
       ctx->EmplaceBackOutput(out_ptr);
     } else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
-          &(inner_scope->FindVar(name)->Get<phi::DenseTensor>()))));
+          &(inner_scope->FindVar(name_map.at(out_ptr))
+                ->Get<phi::DenseTensor>()))));
     } else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
       ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
-          &(inner_scope->FindVar(name)->Get<phi::SelectedRows>()))));
+          &(inner_scope->FindVar(name_map.at(out_ptr))
+                ->Get<phi::SelectedRows>()))));
     } else if (out_type.isa<ir::VectorType>()) {
       OutListType outputs;
       auto& variable_array =
-          scope->FindVar(name)
+          scope->FindVar(name_map.at(out_ptr))
               ->Get<paddle::framework::VariableRefArray>();
       for (size_t i = 0; i < variable_array.size(); ++i) {
         outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
             &(variable_array[i]->Get<phi::DenseTensor>()))));
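
The reshuffled output loop defers the name_map.at(out_ptr) lookup until the result is known to carry a type: an optional output that was never defined has no entry in the name map, and an unconditional .at() would throw before the `if (!out_type)` branch could emplace its null placeholder. A standalone sketch of the guard (the map, keys, and type strings are stand-ins):

    #include <iostream>
    #include <map>
    #include <optional>
    #include <string>

    int main() {
      std::map<int, std::string> name_map = {{0, "out_var"}};  // only typed results are named

      for (int result : {0, 1}) {
        std::optional<std::string> out_type =
            (result == 0) ? std::optional<std::string>("dense_tensor") : std::nullopt;

        if (out_type) {
          // Safe: a typed result is guaranteed to have a name_map entry.
          std::cout << "EmplaceBackOutput: " << name_map.at(result) << "\n";
        } else {
          // An unconditional name_map.at(result) would throw here.
          std::cout << "EmplaceBackOutput: an optional (undefined) output\n";
        }
      }
    }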

paddle/fluid/ir_adaptor/translator/op_translator.cc

@@ -72,9 +72,7 @@ using AttributeHandlerFn = std::function<ir::Attribute(
 constexpr char kTargetDialectPrefix[] = "pd.";
 constexpr char kEmptyVarName[] = "@EMPTY@";

-static const std::unordered_set<std::string> special_non_inplace_ops = {
-    "batch_norm",
-};
+static const std::unordered_set<std::string> special_non_inplace_ops = {};

 static const std::unordered_set<std::string> special_inplace_ops = {
     "adagrad",

paddle/fluid/ir_adaptor/translator/program_translator.cc

@@ -77,6 +77,11 @@ void ProgramTranslator::Translate() {
     const BlockDesc& block = legacy_program_->Block(block_idx);
     SetStopGradientAttributeForAllValue(block);
   }
+
+  for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
+    const BlockDesc& block = legacy_program_->Block(block_idx);
+    SetIsPersisableAttributeForAllValue(block);
+  }
 }

 inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,

@@ -268,5 +273,44 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
   }
 }

+void ProgramTranslator::SetIsPersisableAttributeForAllValue(
+    const BlockDesc& block) {
+  // Currently we set is persisable for operation that generated a value
+  // connected with VarDesc
+  for (const auto& [var_name, value_info] : param_map_) {
+    if (no_cast_var_names.count(var_name) != 0) continue;
+    VLOG(10) << "[op translated][is persisable]" << var_name;
+    VarDesc* var = block.FindVarRecursive(var_name);
+    if (var == nullptr) {
+      continue;
+    }
+    ir::OpResult value = value_info.value;
+    if (!value) {
+      PADDLE_THROW(phi::errors::PreconditionNotMet(
+          "Value of [%s] can not ber None", var_name));
+    }
+    auto* defining_op = value.owner();
+    PADDLE_ENFORCE_NOT_NULL(
+        defining_op,
+        phi::errors::PreconditionNotMet(
+            "Defining operator of [%s] can not be nullptr", var_name));
+    VLOG(8) << "[op translated][is persisable]" << var_name
+            << " from: " << defining_op->name();
+    std::vector<ir::Attribute> is_persisable;
+    if (defining_op->HasAttribute(kAttrIsPersisable)) {
+      is_persisable = defining_op->attribute(kAttrIsPersisable)
+                          .dyn_cast<ir::ArrayAttribute>()
+                          .AsVector();
+    } else {
+      is_persisable = std::vector<ir::Attribute>(
+          defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
+    }
+    is_persisable[value.GetResultIndex()] =
+        ir::BoolAttribute::get(ctx_, var->Persistable());
+    defining_op->set_attribute(kAttrIsPersisable,
+                               ir::ArrayAttribute::get(ctx_, is_persisable));
+  }
+}
+
 }  // namespace translator
 }  // namespace paddle
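
SetIsPersisableAttributeForAllValue stamps, onto each defining op, an array attribute with one boolean per op result, flipping only the slot whose result corresponds to the translated VarDesc; slots that were never set default to false. A standalone sketch of that per-result bookkeeping (a plain vector<bool> stands in for the array of ir::BoolAttribute):

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      const int num_results = 3;         // results of one defining op
      const int persistable_result = 1;  // result that maps to a persistable VarDesc

      // Attribute map of the op; "is_persisable" holds one flag per result.
      std::unordered_map<std::string, std::vector<bool>> attrs;

      // Start from the existing attribute if present, otherwise all-false.
      std::vector<bool> is_persisable =
          attrs.count("is_persisable") ? attrs["is_persisable"]
                                       : std::vector<bool>(num_results, false);
      is_persisable[persistable_result] = true;
      attrs["is_persisable"] = is_persisable;

      for (size_t i = 0; i < attrs["is_persisable"].size(); ++i) {
        std::cout << "result " << i << ": " << attrs["is_persisable"][i] << "\n";
      }
    }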

paddle/fluid/ir_adaptor/translator/program_translator.h

@@ -79,6 +79,7 @@ class ProgramTranslator {
   void InsertOperationToSingleBlock(const BlockDesc& block);
   void SetParameterFromSingleBlock(const BlockDesc& block);
   void SetStopGradientAttributeForAllValue(const BlockDesc& block);
+  void SetIsPersisableAttributeForAllValue(const BlockDesc& block);
 };

 }  // namespace translator

paddle/ir/core/attribute.h

@@ -18,6 +18,7 @@
 #include "paddle/ir/core/type_id.h"

 constexpr char kAttrStopGradients[] = "stop_gradient";
+constexpr char kAttrIsPersisable[] = "is_persisable";

 namespace ir {
 class AttributeStorage;