Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c5a191bb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c5a191bb
编写于
7月 10, 2023
作者:
K
kangguangli
提交者:
GitHub
7月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NewIR] add stop_gradient attribute for defining op (#55235)
* add stop_gradient attribute for defining op * modify by reviews * fix
上级
4905a247
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
63 addition
and
14 deletion
+63
-14
paddle/fluid/ir_adaptor/translator/op_translator.cc
paddle/fluid/ir_adaptor/translator/op_translator.cc
+4
-5
paddle/fluid/ir_adaptor/translator/program_translator.cc
paddle/fluid/ir_adaptor/translator/program_translator.cc
+43
-5
paddle/fluid/ir_adaptor/translator/program_translator.h
paddle/fluid/ir_adaptor/translator/program_translator.h
+2
-1
paddle/ir/core/operation.cc
paddle/ir/core/operation.cc
+5
-0
paddle/ir/core/operation.h
paddle/ir/core/operation.h
+7
-1
test/cpp/ir/core/ir_op_test.cc
test/cpp/ir/core/ir_op_test.cc
+2
-2
未找到文件。
paddle/fluid/ir_adaptor/translator/op_translator.cc
浏览文件 @
c5a191bb
...
@@ -330,10 +330,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
...
@@ -330,10 +330,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
std
::
set
<
std
::
string
>
yaml_input_set
;
std
::
set
<
std
::
string
>
yaml_input_set
;
for
(
const
auto
&
info
:
input_infos
)
{
for
(
const
auto
&
info
:
input_infos
)
{
if
(
auto
special_handler
=
this
->
GetSpecialInputHandlers
(
info
.
name
))
{
continue
;
}
std
::
string
legacy_input_name
=
std
::
string
legacy_input_name
=
op_normalizer
.
GetLegacyArgName
(
op_desc
.
Type
(),
info
.
name
);
op_normalizer
.
GetLegacyArgName
(
op_desc
.
Type
(),
info
.
name
);
...
@@ -381,7 +377,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
...
@@ -381,7 +377,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
std
::
vector
<
std
::
string
>
legacy_input_vars
;
std
::
vector
<
std
::
string
>
legacy_input_vars
;
// return empty OpResult if this arg is optional and not shown in OpDesc
// return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute
if
(
op_desc
.
HasInput
(
legacy_input_name
,
true
))
{
if
(
op_desc
.
HasInput
(
legacy_input_name
,
true
))
{
legacy_input_vars
=
op_desc
.
Input
(
legacy_input_name
,
true
);
legacy_input_vars
=
op_desc
.
Input
(
legacy_input_name
,
true
);
}
}
...
@@ -436,6 +431,10 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
...
@@ -436,6 +431,10 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(
// if src type is Tensor
// if src type is Tensor
if
(
!
is_vector
)
{
if
(
!
is_vector
)
{
IR_ENFORCE
(
legacy_input_vars
.
size
()
==
1u
,
"Input %s not found when parsing op %s"
,
info
.
name
,
op_desc
.
Type
());
auto
defining_info
=
(
*
param_map
)[
legacy_input_vars
[
0
]];
auto
defining_info
=
(
*
param_map
)[
legacy_input_vars
[
0
]];
op_inputs
.
push_back
(
defining_info
.
value
);
op_inputs
.
push_back
(
defining_info
.
value
);
...
...
paddle/fluid/ir_adaptor/translator/program_translator.cc
浏览文件 @
c5a191bb
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/enforce.h"
...
@@ -38,17 +39,19 @@ using ProgramDesc = ::paddle::framework::ProgramDesc;
...
@@ -38,17 +39,19 @@ using ProgramDesc = ::paddle::framework::ProgramDesc;
using
BlockDesc
=
::
paddle
::
framework
::
BlockDesc
;
using
BlockDesc
=
::
paddle
::
framework
::
BlockDesc
;
using
VarDesc
=
::
paddle
::
framework
::
VarDesc
;
using
VarDesc
=
::
paddle
::
framework
::
VarDesc
;
const
std
::
unordered_set
<
std
::
string
>
ProgramTranslator
::
no_cast_var_names
=
{
"feed"
,
"fetch"
,
};
constexpr
char
kAttrStopGradients
[]
=
"stop_gradient"
;
ProgramTranslator
::
ProgramTranslator
(
const
ProgramDesc
*
legacy_program
,
ProgramTranslator
::
ProgramTranslator
(
const
ProgramDesc
*
legacy_program
,
ir
::
Program
*
program
)
ir
::
Program
*
program
)
:
legacy_program_
(
legacy_program
),
program_
(
program
)
{
:
legacy_program_
(
legacy_program
),
program_
(
program
)
{
ctx_
=
ir
::
IrContext
::
Instance
();
ctx_
=
ir
::
IrContext
::
Instance
();
}
}
const
std
::
unordered_set
<
std
::
string
>
ProgramTranslator
::
no_cast_var_names
=
{
"feed"
,
"fetch"
,
};
void
ProgramTranslator
::
Translate
()
{
void
ProgramTranslator
::
Translate
()
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
legacy_program_
->
Size
(),
legacy_program_
->
Size
(),
...
@@ -71,6 +74,11 @@ void ProgramTranslator::Translate() {
...
@@ -71,6 +74,11 @@ void ProgramTranslator::Translate() {
const
BlockDesc
&
block
=
legacy_program_
->
Block
(
block_idx
);
const
BlockDesc
&
block
=
legacy_program_
->
Block
(
block_idx
);
SetParameterFromSingleBlock
(
block
);
SetParameterFromSingleBlock
(
block
);
}
}
for
(
size_t
block_idx
=
0
;
block_idx
<
legacy_program_
->
Size
();
block_idx
++
)
{
const
BlockDesc
&
block
=
legacy_program_
->
Block
(
block_idx
);
SetStopGradientAttributeForAllValue
(
block
);
}
}
}
inline
ir
::
Operation
*
InsertGetParamaterOp
(
ir
::
IrContext
*
ctx
,
inline
ir
::
Operation
*
InsertGetParamaterOp
(
ir
::
IrContext
*
ctx
,
...
@@ -198,5 +206,35 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
...
@@ -198,5 +206,35 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
}
}
}
}
void
ProgramTranslator
::
SetStopGradientAttributeForAllValue
(
const
BlockDesc
&
block
)
{
// Currently we set stop gradient for operation that generated a value
// connected with VarDesc
for
(
const
auto
&
[
var_name
,
value_info
]
:
param_map_
)
{
VLOG
(
10
)
<<
"[op translated][stop gradient]"
<<
var_name
;
VarDesc
*
var
=
block
.
FindVarRecursive
(
var_name
);
if
(
var
==
nullptr
)
{
continue
;
}
ir
::
OpResult
value
=
value_info
.
value
;
auto
*
defining_op
=
value
.
owner
();
VLOG
(
8
)
<<
"[op translated][stop gradient]"
<<
var_name
<<
" from: "
<<
defining_op
->
name
();
std
::
vector
<
ir
::
Attribute
>
stop_gradients
;
if
(
defining_op
->
HasAttribute
(
kAttrStopGradients
))
{
stop_gradients
=
defining_op
->
attribute
(
kAttrStopGradients
)
.
dyn_cast
<
ir
::
ArrayAttribute
>
()
.
data
();
}
else
{
stop_gradients
=
std
::
vector
<
ir
::
Attribute
>
(
defining_op
->
num_results
(),
ir
::
BoolAttribute
::
get
(
ctx_
,
false
));
}
stop_gradients
[
value
.
GetResultIndex
()]
=
ir
::
BoolAttribute
::
get
(
ctx_
,
var
->
StopGradient
());
defining_op
->
set_attribute
(
kAttrStopGradients
,
ir
::
ArrayAttribute
::
get
(
ctx_
,
stop_gradients
));
}
}
}
// namespace translator
}
// namespace translator
}
// namespace paddle
}
// namespace paddle
paddle/fluid/ir_adaptor/translator/program_translator.h
浏览文件 @
c5a191bb
...
@@ -72,12 +72,13 @@ class ProgramTranslator {
...
@@ -72,12 +72,13 @@ class ProgramTranslator {
/// 2. "fetch", the output variable of fetch op
/// 2. "fetch", the output variable of fetch op
/// However, new feed has no input and new fetch has no output
/// However, new feed has no input and new fetch has no output
/// So we don't handle these two vairables when
/// So we don't handle these two vairables when
/// `
Extrac
tParameterFromSingleBlock`
/// `
Get/Se
tParameterFromSingleBlock`
static
const
std
::
unordered_set
<
std
::
string
>
no_cast_var_names
;
static
const
std
::
unordered_set
<
std
::
string
>
no_cast_var_names
;
void
GetParameterForSingleBlock
(
const
BlockDesc
&
block
);
void
GetParameterForSingleBlock
(
const
BlockDesc
&
block
);
void
InsertOperationToSingleBlock
(
const
BlockDesc
&
block
);
void
InsertOperationToSingleBlock
(
const
BlockDesc
&
block
);
void
SetParameterFromSingleBlock
(
const
BlockDesc
&
block
);
void
SetParameterFromSingleBlock
(
const
BlockDesc
&
block
);
void
SetStopGradientAttributeForAllValue
(
const
BlockDesc
&
block
);
};
};
}
// namespace translator
}
// namespace translator
...
...
paddle/ir/core/operation.cc
浏览文件 @
c5a191bb
...
@@ -205,6 +205,11 @@ std::string Operation::name() const {
...
@@ -205,6 +205,11 @@ std::string Operation::name() const {
return
p_name
?
p_name
:
""
;
return
p_name
?
p_name
:
""
;
}
}
Attribute
Operation
::
attribute
(
const
std
::
string
&
key
)
const
{
IR_ENFORCE
(
HasAttribute
(
key
),
"operation(%s): no attribute %s"
,
name
(),
key
);
return
attributes_
.
at
(
key
);
}
Region
*
Operation
::
GetParentRegion
()
const
{
Region
*
Operation
::
GetParentRegion
()
const
{
return
parent_
?
parent_
->
GetParent
()
:
nullptr
;
return
parent_
?
parent_
->
GetParent
()
:
nullptr
;
}
}
...
...
paddle/ir/core/operation.h
浏览文件 @
c5a191bb
...
@@ -64,10 +64,16 @@ class IR_API alignas(8) Operation final {
...
@@ -64,10 +64,16 @@ class IR_API alignas(8) Operation final {
const
AttributeMap
&
attributes
()
const
{
return
attributes_
;
}
const
AttributeMap
&
attributes
()
const
{
return
attributes_
;
}
void
SetA
ttribute
(
const
std
::
string
&
key
,
Attribute
value
)
{
void
set_a
ttribute
(
const
std
::
string
&
key
,
Attribute
value
)
{
attributes_
[
key
]
=
value
;
attributes_
[
key
]
=
value
;
}
}
Attribute
attribute
(
const
std
::
string
&
key
)
const
;
bool
HasAttribute
(
const
std
::
string
&
key
)
const
{
return
attributes_
.
find
(
key
)
!=
attributes_
.
end
();
}
ir
::
OpInfo
info
()
const
{
return
info_
;
}
ir
::
OpInfo
info
()
const
{
return
info_
;
}
uint32_t
num_results
()
const
{
return
num_results_
;
}
uint32_t
num_results
()
const
{
return
num_results_
;
}
...
...
test/cpp/ir/core/ir_op_test.cc
浏览文件 @
c5a191bb
...
@@ -274,6 +274,6 @@ TEST(op_test, module_op_death) {
...
@@ -274,6 +274,6 @@ TEST(op_test, module_op_death) {
EXPECT_EQ
(
program
.
module_op
().
program
(),
&
program
);
EXPECT_EQ
(
program
.
module_op
().
program
(),
&
program
);
EXPECT_EQ
(
program
.
module_op
().
ir_context
(),
ctx
);
EXPECT_EQ
(
program
.
module_op
().
ir_context
(),
ctx
);
program
.
module_op
()
->
SetA
ttribute
(
"program"
,
program
.
module_op
()
->
set_a
ttribute
(
"program"
,
ir
::
PointerAttribute
::
get
(
ctx
,
&
program
));
ir
::
PointerAttribute
::
get
(
ctx
,
&
program
));
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录