Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
778ea4ec
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
778ea4ec
编写于
5月 09, 2022
作者:
C
Chen Weihang
提交者:
GitHub
5月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Eager] Polish grad code details (#42536)
* polish grad details * polish detail by comment
上级
13bcb7cd
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
128 addition
and
133 deletion
+128
-133
paddle/fluid/eager/backward.cc
paddle/fluid/eager/backward.cc
+72
-71
paddle/fluid/eager/grad_node_info.cc
paddle/fluid/eager/grad_node_info.cc
+26
-1
paddle/fluid/eager/grad_node_info.h
paddle/fluid/eager/grad_node_info.h
+22
-53
paddle/fluid/eager/tensor_wrapper.h
paddle/fluid/eager/tensor_wrapper.h
+8
-8
未找到文件。
paddle/fluid/eager/backward.cc
浏览文件 @
778ea4ec
...
...
@@ -66,68 +66,69 @@ class GeneralGrad {
"stop_gradient=True."
,
msg
,
i
));
if
(
is_no_grad_vars
)
{
(
no_grad_var_nodes_inputmeta_map
)[
target_node
]
=
auto_grad_meta
;
(
no_grad_var_nodes_inputmeta_map
_
)[
target_node
]
=
auto_grad_meta
;
}
else
{
// normal input
(
input_target_nodes_inputmeta_map
)[
target_node
]
=
auto_grad_meta
;
(
input_target_nodes_inputmeta_map
_
)[
target_node
]
=
auto_grad_meta
;
}
}
}
}
// Purify potential_startup_nodes, remove nodes those are the same as
// Purify potential_startup_nodes
_
, remove nodes those are the same as
// input_target_nodes
void
PurifyPotentialStartUpNodes
()
{
VLOG
(
6
)
<<
"Running in PurifyPotentialStartUpNodes"
;
if
(
input_target_nodes_inputmeta_map
.
empty
())
return
;
if
(
input_target_nodes_inputmeta_map
_
.
empty
())
return
;
std
::
unordered_set
<
GradNodeBase
*>
potential_startup_nodes_to_be_erased
;
for
(
auto
startup_op
:
potential_startup_nodes
)
{
auto
iter
=
input_target_nodes_inputmeta_map
.
find
(
startup_op
);
if
(
iter
!=
input_target_nodes_inputmeta_map
.
end
())
{
for
(
auto
startup_op
:
potential_startup_nodes
_
)
{
auto
iter
=
input_target_nodes_inputmeta_map
_
.
find
(
startup_op
);
if
(
iter
!=
input_target_nodes_inputmeta_map
_
.
end
())
{
potential_startup_nodes_to_be_erased
.
emplace
(
iter
->
first
);
}
}
if
(
!
potential_startup_nodes_to_be_erased
.
empty
())
{
for
(
auto
nodes
:
potential_startup_nodes_to_be_erased
)
{
potential_startup_nodes
.
erase
(
nodes
);
potential_startup_nodes
_
.
erase
(
nodes
);
}
}
}
// Remove some nodes those doesn't need to be
// stored in potential_stop_nodes
、potential_startup_nodes
// stored in potential_stop_nodes
_、potential_startup_nodes_
void
UpdateGraphInfo
()
{
// Updated potential_sotp_nodes by depending_nodes,
// Updated potential_sotp_nodes by depending_nodes
_
,
// make sure the path from root to target_node is ok
std
::
unordered_set
<
GradNodeBase
*>
_
startup_ops
;
std
::
unordered_set
<
GradNodeBase
*>
startup_ops
;
VLOG
(
6
)
<<
"Running in UpdateGraphInfo"
;
std
::
queue
<
GradNodeBase
*>
queue
;
for
(
auto
&
target_nodes_inputmeta_pair
:
input_target_nodes_inputmeta_map
)
{
for
(
auto
&
target_nodes_inputmeta_pair
:
input_target_nodes_inputmeta_map_
)
{
queue
.
emplace
(
target_nodes_inputmeta_pair
.
first
);
}
while
(
!
queue
.
empty
())
{
auto
*
target_node
=
queue
.
front
();
queue
.
pop
();
if
(
!
(
depending_nodes
)[
target_node
].
empty
())
{
auto
precedding_nodes
=
(
depending_nodes
)[
target_node
];
if
(
!
(
depending_nodes
_
)[
target_node
].
empty
())
{
auto
precedding_nodes
=
(
depending_nodes
_
)[
target_node
];
for
(
auto
pre_nodes
:
precedding_nodes
)
{
queue
.
emplace
(
pre_nodes
);
if
(
potential_stop_nodes
.
find
(
pre_nodes
)
!=
potential_stop_nodes
.
end
())
{
potential_stop_nodes
.
erase
(
pre_nodes
);
if
(
potential_stop_nodes
_
.
find
(
pre_nodes
)
!=
potential_stop_nodes
_
.
end
())
{
potential_stop_nodes
_
.
erase
(
pre_nodes
);
}
}
}
else
{
// startup_ops have no precedding nodes
VLOG
(
6
)
<<
"Emplace
_
startup_ops"
;
_
startup_ops
.
emplace
(
target_node
);
VLOG
(
6
)
<<
"Emplace startup_ops"
;
startup_ops
.
emplace
(
target_node
);
}
}
// Purify potential_startup_nodes again, remove some
// Purify potential_startup_nodes
_
again, remove some
// potential startup_nodes that unreach to input target nodes
if
(
!
_
startup_ops
.
empty
())
{
if
(
!
startup_ops
.
empty
())
{
std
::
unordered_set
<
GradNodeBase
*>
potential_startup_nodes_to_be_erased
;
for
(
auto
node
:
potential_startup_nodes
)
{
if
(
_
startup_ops
.
count
(
node
)
==
0
)
{
for
(
auto
node
:
potential_startup_nodes
_
)
{
if
(
startup_ops
.
count
(
node
)
==
0
)
{
VLOG
(
6
)
<<
"Set up potential_startup_nodes_to_be_erased"
;
potential_startup_nodes_to_be_erased
.
emplace
(
node
);
}
...
...
@@ -135,14 +136,14 @@ class GeneralGrad {
if
(
!
potential_startup_nodes_to_be_erased
.
empty
())
{
for
(
auto
node
:
potential_startup_nodes_to_be_erased
)
{
VLOG
(
6
)
<<
"Erase nodes in potential_startup_nodes_to_be_erased"
;
potential_startup_nodes
.
erase
(
node
);
potential_startup_nodes
_
.
erase
(
node
);
}
}
}
}
// Get Graph Info Betweent input target GradNode and outputs,
// record depending_nodes
、potential_stop_nodes、potential_startup_nodes
// record depending_nodes
_、potential_stop_nodes_、potential_startup_nodes_
void
GetGraphInfoBetweenTargets
(
const
std
::
queue
<
GradNodeBase
*>&
init_queue
)
{
VLOG
(
6
)
<<
"Runing In GetGraphInfoBetweenTargets"
;
...
...
@@ -164,9 +165,9 @@ class GeneralGrad {
visited
.
insert
(
node
);
// Check node is target_nodes or not, if node is not target_node,
// all the next_node will be marked in potential_stop_nodes
// all the next_node will be marked in potential_stop_nodes
_
bool
is_potential_stop_nodes
=
input_target_nodes_inputmeta_map
.
count
(
node
);
input_target_nodes_inputmeta_map
_
.
count
(
node
);
// Find and append next nodes
const
paddle
::
small_vector
<
std
::
vector
<
GradSlotMeta
>
,
...
...
@@ -186,40 +187,41 @@ class GeneralGrad {
// all the next_nodes of current node will be inserted to
// potential_stop_node
if
(
is_potential_stop_nodes
)
{
potential_stop_nodes
.
emplace
(
next_node
);
potential_stop_nodes
_
.
emplace
(
next_node
);
}
// Update in_degree
if
(
!
node_in_degree_map
.
count
(
next_node
))
if
(
!
node_in_degree_map
.
count
(
next_node
))
{
node_in_degree_map
[
next_node
]
=
0
;
}
node_in_degree_map
[
next_node
]
++
;
// Record depending relationship
(
depending_nodes
)[
next_node
].
emplace
(
node
);
(
depending_nodes
_
)[
next_node
].
emplace
(
node
);
queue
.
push
(
next_node
);
}
}
}
// Update Graph Info, remove some nodes in
// potential_stop_nodes
、potential_startup_nodes
、
// potential_stop_nodes
_、potential_startup_nodes_
、
UpdateGraphInfo
();
}
void
ModifyReadyQueue
(
std
::
queue
<
GradNodeBase
*>*
queue
)
{
std
::
queue
<
GradNodeBase
*>
tmp_queue
;
for
(
auto
nodes
:
potential_startup_nodes
)
{
for
(
auto
nodes
:
potential_startup_nodes
_
)
{
tmp_queue
.
emplace
(
nodes
);
}
tmp_queue
.
swap
(
*
queue
);
}
// Set result for input target grad_var when potential_startup_nodes is empty
// Set result for input target grad_var when potential_startup_nodes
_
is empty
void
SetResultForInputTargetVar
(
const
std
::
unordered_map
<
GradNodeBase
*
,
std
::
unique_ptr
<
GradTensorHolder
>>&
node_input_buffers_dict
)
{
if
(
potential_startup_nodes
.
size
()
==
0
)
{
for
(
auto
input_target_node
:
*
GetIn
P
utTargetNodesInputMetaMap
())
{
if
(
potential_startup_nodes
_
.
size
()
==
0
)
{
for
(
auto
input_target_node
:
*
GetIn
p
utTargetNodesInputMetaMap
())
{
// out rank_info of forward op
auto
rank_info
=
input_target_node
.
second
->
OutRankInfo
();
auto
iter
=
node_input_buffers_dict
.
find
(
input_target_node
.
first
);
...
...
@@ -227,7 +229,7 @@ class GeneralGrad {
auto
&
target_result
=
(
iter
->
second
)
->
Buffers
()[
rank_info
.
first
][
rank_info
.
second
];
// save the target result
results_map
[
input_target_node
.
first
]
=
target_result
;
results_map
_
[
input_target_node
.
first
]
=
target_result
;
}
}
}
...
...
@@ -236,8 +238,8 @@ class GeneralGrad {
// Set input target grad_var from node_input_buffer by inputmeta
void
SetResultForInputTargetVar
(
GradTensorHolder
input_buffers
,
GradNodeBase
*
node
)
{
auto
iter
=
GetIn
P
utTargetNodesInputMetaMap
()
->
find
(
node
);
if
(
iter
!=
GetIn
P
utTargetNodesInputMetaMap
()
->
end
())
{
auto
iter
=
GetIn
p
utTargetNodesInputMetaMap
()
->
find
(
node
);
if
(
iter
!=
GetIn
p
utTargetNodesInputMetaMap
()
->
end
())
{
VLOG
(
6
)
<<
"Get target result by by inputmeta"
;
// out rank_info of forward op
auto
rank_info
=
(
iter
->
second
)
->
OutRankInfo
();
...
...
@@ -245,7 +247,7 @@ class GeneralGrad {
auto
&
target_result
=
input_buffers
.
Buffers
()[
rank_info
.
first
][
rank_info
.
second
];
// save the target result
results_map
[
node
]
=
target_result
;
results_map
_
[
node
]
=
target_result
;
}
}
...
...
@@ -271,8 +273,8 @@ class GeneralGrad {
"input"
;
}
auto
iter
=
results_map
.
find
(
target_node
);
if
(
iter
!=
results_map
.
end
())
{
auto
iter
=
results_map
_
.
find
(
target_node
);
if
(
iter
!=
results_map
_
.
end
())
{
// set StopGradient = !create_graph
AutogradMeta
*
tensor_auto_grad_meta
=
EagerUtils
::
autograd_meta
(
&
(
iter
->
second
));
...
...
@@ -303,12 +305,12 @@ class GeneralGrad {
GetTargetNodesInfo
(
no_grad_vars
,
true
/* is_no_grad_vars */
);
// Get inputs's GradNodes and InputMeta Info
GetTargetNodesInfo
(
inputs
,
false
/* is_no_grad_vars */
);
// Purify potential
_
startup_ops, remove those nodes that are the same as
// Purify potentialstartup_ops, remove those nodes that are the same as
// input_target_nodes
PurifyPotentialStartUpNodes
();
// Get Graph Info Betweent input target gradnode and outputs
// Record the depending_nodes and
// potential_stop_nodes
、potential_startup_nodes
// Record the depending_nodes
_
and
// potential_stop_nodes
_、potential_startup_nodes_
GetGraphInfoBetweenTargets
(
*
queue
);
// Reset queue. Queue is empty only when
// 1.input equals to output. 2.input can not reach to output.
...
...
@@ -318,34 +320,34 @@ class GeneralGrad {
}
bool
IsPotentialStopNodes
(
GradNodeBase
*
node
)
{
return
potential_stop_nodes
.
count
(
node
);
return
potential_stop_nodes
_
.
count
(
node
);
}
std
::
unordered_map
<
GradNodeBase
*
,
AutogradMeta
*>*
GetNoGradVarNodesInputMetaMap
()
{
return
&
no_grad_var_nodes_inputmeta_map
;
return
&
no_grad_var_nodes_inputmeta_map
_
;
}
std
::
unordered_map
<
GradNodeBase
*
,
AutogradMeta
*>*
GetIn
P
utTargetNodesInputMetaMap
()
{
return
&
input_target_nodes_inputmeta_map
;
GetIn
p
utTargetNodesInputMetaMap
()
{
return
&
input_target_nodes_inputmeta_map
_
;
}
std
::
unordered_set
<
GradNodeBase
*>*
GetPotentialStopNodes
()
{
return
&
potential_stop_nodes
;
return
&
potential_stop_nodes
_
;
}
std
::
unordered_set
<
GradNodeBase
*>*
GetPotentialStartupNodes
()
{
return
&
potential_startup_nodes
;
return
&
potential_startup_nodes
_
;
}
void
Clear
()
{
no_grad_var_nodes_inputmeta_map
.
clear
();
input_target_nodes_inputmeta_map
.
clear
();
potential_startup_nodes
.
clear
();
potential_stop_nodes
.
clear
();
depending_nodes
.
clear
();
results_map
.
clear
();
no_grad_var_nodes_inputmeta_map
_
.
clear
();
input_target_nodes_inputmeta_map
_
.
clear
();
potential_startup_nodes
_
.
clear
();
potential_stop_nodes
_
.
clear
();
depending_nodes
_
.
clear
();
results_map
_
.
clear
();
copied_grad_nodes_
.
clear
();
orig_to_copied_node_mapping_
.
clear
();
}
...
...
@@ -426,18 +428,18 @@ class GeneralGrad {
static
GeneralGrad
*
general_grad_
;
// no_grad_vars's GradNode and GradNode's InputMeta.
std
::
unordered_map
<
GradNodeBase
*
,
AutogradMeta
*
/* InputMeta */
>
no_grad_var_nodes_inputmeta_map
;
no_grad_var_nodes_inputmeta_map
_
;
// inputs's GradNode and GradNode's InputMeta.
std
::
unordered_map
<
GradNodeBase
*
,
AutogradMeta
*
/* InputMeta */
>
input_target_nodes_inputmeta_map
;
input_target_nodes_inputmeta_map
_
;
// Record all the potential startup_nodes, will be changed.
std
::
unordered_set
<
GradNodeBase
*>
potential_startup_nodes
;
std
::
unordered_set
<
GradNodeBase
*>
potential_startup_nodes
_
;
// Record all the potential stop nodes, will be changed.
std
::
unordered_set
<
GradNodeBase
*>
potential_stop_nodes
;
std
::
unordered_set
<
GradNodeBase
*>
potential_stop_nodes
_
;
std
::
unordered_map
<
GradNodeBase
*
/* next node */
,
std
::
unordered_set
<
GradNodeBase
*>
/* pre nodes */
>
depending_nodes
;
std
::
unordered_map
<
GradNodeBase
*
,
paddle
::
experimental
::
Tensor
>
results_map
;
depending_nodes
_
;
std
::
unordered_map
<
GradNodeBase
*
,
paddle
::
experimental
::
Tensor
>
results_map
_
;
std
::
vector
<
std
::
shared_ptr
<
GradNodeBase
>>
copied_grad_nodes_
;
std
::
unordered_map
<
GradNodeBase
*
,
std
::
shared_ptr
<
GradNodeBase
>>
...
...
@@ -619,7 +621,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// GradTensorHolder will initialize another tensor with same tensortype,
// datatype and dims but filled with 1.0
node_input_buffers_dict
[
grad_node
]
->
CopyValueFromTensor
(
input_info
.
first
,
input_info
.
second
,
tensor
,
true
/*fill_one=true*/
);
input_info
.
first
,
input_info
.
second
,
tensor
,
/*fill_one=*/
true
);
}
// Prepare queue, potential startup_nodes
...
...
@@ -657,7 +659,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
VLOG
(
6
)
<<
"Running GradNode:"
<<
node
->
name
();
paddle
::
platform
::
RecordEvent
node_record_event
(
std
::
string
((
*
node
).
name
())
+
" grad_node"
,
std
::
string
((
*
node
).
name
()),
paddle
::
platform
::
TracerEventType
::
Operator
,
1
);
if
(
queue
.
size
()
>
1
&&
node_in_degree_map
[
node
]
!=
0
)
{
...
...
@@ -667,14 +669,15 @@ std::vector<paddle::experimental::Tensor> RunBackward(
queue
.
pop
();
// Run node: This is where Hook happens
PADDLE_ENFORCE
(
node_input_buffers_dict
.
count
(
node
),
auto
node_input_buffer_iter
=
node_input_buffers_dict
.
find
(
node
);
PADDLE_ENFORCE_NE
(
node_input_buffer_iter
,
node_input_buffers_dict
.
end
(),
paddle
::
platform
::
errors
::
Fatal
(
"Unable to find next node in the GradTensorHolder
\n
"
"Trying to run Node without configuring its GradTensorHolder."
));
std
::
unique_ptr
<
GradTensorHolder
>
node_input_buffer
=
std
::
move
(
node_input_buffer
s_dict
[
node
]
);
std
::
move
(
node_input_buffer
_iter
->
second
);
// Set input target grad_var from node_input_buffer by inputmeta
if
(
!
inputs
.
empty
()
&&
is_general_grad
)
{
...
...
@@ -715,8 +718,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
// TODO(jiabin): Should we erase it or find a more efficient way.
node_input_buffers_dict
.
erase
(
node
);
node_input_buffers_dict
.
erase
(
node_input_buffer_iter
);
// Prepare GradTensorHolder for next node
const
paddle
::
small_vector
<
std
::
vector
<
GradSlotMeta
>
,
kSlotSmallVectorSize
>&
...
...
@@ -736,8 +738,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
}
auto
edge_rank
=
edge
.
GetEdgeRankInfo
();
// Since we make edge has as same rank as bwd outputs, we indexing them
// with
// the same rank(i, j)
// with the same rank(i, j)
auto
next_node_shared
=
edge
.
GetMutableGradNode
();
// Next node could be nullptr if it is leaf tensor with no
...
...
paddle/fluid/eager/grad_node_info.cc
浏览文件 @
778ea4ec
...
...
@@ -36,6 +36,31 @@
**/
namespace
egr
{
static
void
CheckTensor
(
const
paddle
::
experimental
::
Tensor
&
pre
,
const
paddle
::
experimental
::
Tensor
&
post
)
{
if
(
!
pre
.
initialized
()
&&
post
.
initialized
())
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
PermissionDenied
(
"The tensor in before and after hook are not consistent"
));
}
if
(
pre
.
initialized
()
&&
post
.
initialized
())
{
VLOG
(
4
)
<<
paddle
::
framework
::
DataType2String
(
pre
.
dtype
())
<<
" "
<<
paddle
::
framework
::
DataType2String
(
post
.
dtype
());
PADDLE_ENFORCE_EQ
(
pre
.
dtype
(),
post
.
dtype
(),
paddle
::
platform
::
errors
::
PermissionDenied
(
"The dtype of tensor before(%s) and after(%s) hook are not "
"consistent"
,
paddle
::
framework
::
DataType2String
(
pre
.
dtype
()),
paddle
::
framework
::
DataType2String
(
post
.
dtype
())));
PADDLE_ENFORCE_EQ
(
pre
.
place
(),
post
.
place
(),
paddle
::
platform
::
errors
::
PermissionDenied
(
"The place of tensor before(%s) and after(%s) "
"hook are not consistent"
,
pre
.
place
().
DebugString
(),
post
.
place
().
DebugString
()));
}
}
GradNodeBase
::
GradNodeBase
(
size_t
bwd_in_slot_num
,
size_t
bwd_out_slot_num
)
{
VLOG
(
6
)
<<
"Construct GradNodeBase"
;
bwd_in_meta_
.
resize
(
bwd_in_slot_num
);
...
...
@@ -271,7 +296,7 @@ void GradNodeBase::SetGradOutMeta(
// Only Copy Meta
phi
::
DenseTensor
*
dense_tensor
=
static_cast
<
phi
::
DenseTensor
*>
(
fwd_in_tensor
.
impl
().
get
());
PADDLE_ENFORCE_NE
(
dense_tensor
->
meta
().
dtype
,
phi
::
DataType
::
UNDEFINED
,
PADDLE_ENFORCE_NE
(
dense_tensor
->
dtype
()
,
phi
::
DataType
::
UNDEFINED
,
paddle
::
platform
::
errors
::
Fatal
(
"Attempting to copy DenseTensorMeta "
"with phi::DataType::UNDEFINED,"
...
...
paddle/fluid/eager/grad_node_info.h
浏览文件 @
778ea4ec
...
...
@@ -30,32 +30,23 @@ namespace egr {
* The GradNodeBase will be held in autograd_meta, and it is also a member of
* Edge, which indicates the edge of backward graph.
*
* TODO
:(yangzhanlue)
GradNodeBase will also in charge of get the correct input
* TODO
(yangzhanlue):
GradNodeBase will also in charge of get the correct input
* from GradOpDescMaker to GradNodeBase.
*
* NOTE:GradNodeBase has a method named run, this method should be overrided by
* the
* specific derived class, it will prepare backward inputs and double backward's
* depends. Then, it will call C++ API of backward kernel functions to finish
* backward computation.
* NOTE: GradNodeBase has a method named run, this method should be overrided by
* the specific derived class, it will prepare backward inputs and double
* backward's depends. Then, it will call C++ API of backward kernel functions
* to finish backward computation.
*
* NOTE:GradNodeBase holds its own inputs and Outputs
* NOTE:
GradNodeBase holds its own inputs and Outputs
*
* Edge is defined to descripe depend of backward, an Edge is what linked
* between two
* node, it should contain a Node and rank of this Node (this is used to
* indicate which
* input of grad this edge belong).
* */
* between two node, it should contain a Node and rank of this Node (this is
* used to indicate which input of grad this edge belong).
**/
class
AutogradMeta
;
class
GradNodeBase
;
/**
* GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
* has lots of operators
* whose backward logic is depends on if it has some specific inputs or outputs.
* So, we need a meta info
* to record it's needs.
* **/
class
Edge
{
public:
// Default constructor for Edges in order to construct it for AutogradMeta
...
...
@@ -64,8 +55,7 @@ class Edge {
// In real use cases we should create Edge from grad node and input rank which
// indicate which edge it is.
// Since we have slot design in operators we will have to locate an edge with
// slot
// and rank.
// slot and rank.
Edge
(
const
std
::
shared_ptr
<
GradNodeBase
>&
grad_node
,
size_t
in_slot_id
,
size_t
in_rank
)
:
in_slot_id_
(
in_slot_id
),
in_rank_
(
in_rank
),
grad_node_
(
grad_node
)
{}
...
...
@@ -120,6 +110,12 @@ class Edge {
size_t
in_rank_
;
std
::
shared_ptr
<
GradNodeBase
>
grad_node_
{
nullptr
};
};
/**
* GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
* has lots of operators whose backward logic is depends on if it has some
* specific inputs or outputs. So, we need a meta info to record it's needs.
**/
class
GradSlotMeta
{
public:
GradSlotMeta
()
=
default
;
...
...
@@ -171,16 +167,13 @@ class GradNodeBase {
/**
* operator() designed to contian the real backward execution logic, it should
* be
* overrided by derived class defined for each operator. It accepts a vector
* of
* Tensor which contains grads input of current operator
* be overrided by derived class defined for each operator. It accepts a
* vector of Tensor which contains grads input of current operator
*
* Note: why we need backward inputs and outputs construct as vector of vector
* of paddle::experimental::Tensor?
* Since all of paddle op composite in form of {"Slot name ", vector<Var>},
* so, vector of vector
* is better choice to fit this format.
* so, vector of vector is better choice to fit this format.
* **/
virtual
paddle
::
small_vector
<
std
::
vector
<
paddle
::
experimental
::
Tensor
>
,
kSlotSmallVectorSize
>
...
...
@@ -294,36 +287,12 @@ class GradNodeBase {
/* slot id */
size_t
,
/* rank */
size_t
,
/* hook */
std
::
shared_ptr
<
TensorHook
>>>
gradient_hooks_
;
int64_t
next_hook_id_
{
0
};
// We handle complex to real conversion only if any complex GradIn is involved
bool
need_complex_to_real_
=
false
;
int64_t
next_hook_id_
{
0
};
bool
is_tensor_wrappers_cleared_
=
false
;
};
inline
void
CheckTensor
(
const
paddle
::
experimental
::
Tensor
&
pre
,
const
paddle
::
experimental
::
Tensor
&
post
)
{
if
(
!
pre
.
initialized
()
&&
post
.
initialized
())
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
PermissionDenied
(
"The tensor in before and after hook are not consistent"
));
}
if
(
pre
.
initialized
()
&&
post
.
initialized
())
{
VLOG
(
4
)
<<
paddle
::
framework
::
DataType2String
(
pre
.
dtype
())
<<
" "
<<
paddle
::
framework
::
DataType2String
(
post
.
dtype
());
PADDLE_ENFORCE_EQ
(
pre
.
dtype
(),
post
.
dtype
(),
paddle
::
platform
::
errors
::
PermissionDenied
(
"The dtype of tensor before(%s) and after(%s) hook are not "
"consistent"
,
paddle
::
framework
::
DataType2String
(
pre
.
dtype
()),
paddle
::
framework
::
DataType2String
(
post
.
dtype
())));
PADDLE_ENFORCE_EQ
(
pre
.
place
(),
post
.
place
(),
paddle
::
platform
::
errors
::
PermissionDenied
(
"The place of tensor before(%s) and after(%s) "
"hook are not consistent"
,
pre
.
place
().
DebugString
(),
post
.
place
().
DebugString
()));
}
}
}
// namespace egr
paddle/fluid/eager/tensor_wrapper.h
浏览文件 @
778ea4ec
...
...
@@ -88,6 +88,7 @@ class TensorWrapper {
}
else
{
intermidiate_tensor_
.
set_impl
(
tensor
.
impl
());
}
// TODO(jiabin): This may has server performance issue
intermidiate_tensor_
.
set_name
(
tensor
.
name
()
+
"@Saved"
);
...
...
@@ -118,24 +119,25 @@ class TensorWrapper {
paddle
::
experimental
::
Tensor
recovered_tensor
=
intermidiate_tensor_
;
std
::
shared_ptr
<
GradNodeBase
>
new_grad_node
=
weak_grad_node_
.
lock
();
if
(
new_grad_node
)
{
VLOG
(
3
)
<<
"Recovered TensorWrapper with GradNode "
<<
new_grad_node
->
name
()
<<
" addr: "
<<
new_grad_node
.
get
();
}
else
{
VLOG
(
3
)
<<
"Recovered TensorWrapper with Empth GradNode"
;
}
auto
*
intermediate_autograd_meta
=
EagerUtils
::
unsafe_autograd_meta
(
intermidiate_tensor_
);
auto
p_ab_autograd_meta
=
std
::
make_shared
<
AutogradMeta
>
(
*
intermediate_autograd_meta
);
if
(
new_grad_node
)
{
VLOG
(
3
)
<<
"Recovered TensorWrapper with GradNode "
<<
new_grad_node
->
name
()
<<
" addr: "
<<
new_grad_node
.
get
();
p_ab_autograd_meta
->
SetGradNode
(
new_grad_node
);
}
else
{
VLOG
(
3
)
<<
"Recovered TensorWrapper with Empth GradNode"
;
}
recovered_tensor
.
set_autograd_meta
(
p_ab_autograd_meta
);
return
recovered_tensor
;
}
}
void
clear
()
{
intermidiate_tensor_
.
reset
();
}
private:
void
check_inplace_version
()
{
if
(
no_need_buffer_
)
{
VLOG
(
6
)
<<
"There's no need to check inplace_version because "
...
...
@@ -170,8 +172,6 @@ class TensorWrapper {
}
}
void
clear
()
{
intermidiate_tensor_
.
reset
();
}
private:
bool
full_reserved_
=
false
;
bool
no_need_buffer_
=
false
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录