Unverified commit 778ea4ec, authored May 09, 2022 by Chen Weihang and committed via GitHub on May 09, 2022.
[Eager] Polish grad code details (#42536)
* polish grad details
* polish detail by comment
Parent: 13bcb7cd

Showing 4 changed files with 128 additions and 133 deletions (+128 −133).
paddle/fluid/eager/backward.cc        +72 −71
paddle/fluid/eager/grad_node_info.cc  +26 −1
paddle/fluid/eager/grad_node_info.h   +22 −53
paddle/fluid/eager/tensor_wrapper.h   +8 −8
paddle/fluid/eager/backward.cc
@@ -66,68 +66,69 @@ class GeneralGrad {
                "stop_gradient=True.",
                msg, i));
        if (is_no_grad_vars) {
-         (no_grad_var_nodes_inputmeta_map)[target_node] = auto_grad_meta;
+         (no_grad_var_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
        } else {  // normal input
-         (input_target_nodes_inputmeta_map)[target_node] = auto_grad_meta;
+         (input_target_nodes_inputmeta_map_)[target_node] = auto_grad_meta;
        }
      }
    }
  }

-  // Purify potential_startup_nodes, remove nodes those are the same as
+  // Purify potential_startup_nodes_, remove nodes those are the same as
   // input_target_nodes
   void PurifyPotentialStartUpNodes() {
     VLOG(6) << "Running in PurifyPotentialStartUpNodes";
-    if (input_target_nodes_inputmeta_map.empty()) return;
+    if (input_target_nodes_inputmeta_map_.empty()) return;
     std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
-    for (auto startup_op : potential_startup_nodes) {
-      auto iter = input_target_nodes_inputmeta_map.find(startup_op);
-      if (iter != input_target_nodes_inputmeta_map.end()) {
+    for (auto startup_op : potential_startup_nodes_) {
+      auto iter = input_target_nodes_inputmeta_map_.find(startup_op);
+      if (iter != input_target_nodes_inputmeta_map_.end()) {
        potential_startup_nodes_to_be_erased.emplace(iter->first);
      }
    }
    if (!potential_startup_nodes_to_be_erased.empty()) {
      for (auto nodes : potential_startup_nodes_to_be_erased) {
-        potential_startup_nodes.erase(nodes);
+        potential_startup_nodes_.erase(nodes);
      }
    }
  }

   // Remove some nodes those doesn't need to be
-  // stored in potential_stop_nodes、potential_startup_nodes
+  // stored in potential_stop_nodes_、potential_startup_nodes_
   void UpdateGraphInfo() {
-    // Updated potential_sotp_nodes by depending_nodes,
+    // Updated potential_sotp_nodes by depending_nodes_,
     // make sure the path from root to target_node is ok
-    std::unordered_set<GradNodeBase*> _startup_ops;
+    std::unordered_set<GradNodeBase*> startup_ops;
     VLOG(6) << "Running in UpdateGraphInfo";
     std::queue<GradNodeBase*> queue;
-    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map) {
+    for (auto& target_nodes_inputmeta_pair : input_target_nodes_inputmeta_map_) {
       queue.emplace(target_nodes_inputmeta_pair.first);
     }
     while (!queue.empty()) {
       auto* target_node = queue.front();
       queue.pop();
-      if (!(depending_nodes)[target_node].empty()) {
-        auto precedding_nodes = (depending_nodes)[target_node];
+      if (!(depending_nodes_)[target_node].empty()) {
+        auto precedding_nodes = (depending_nodes_)[target_node];
         for (auto pre_nodes : precedding_nodes) {
           queue.emplace(pre_nodes);
-          if (potential_stop_nodes.find(pre_nodes) !=
-              potential_stop_nodes.end()) {
-            potential_stop_nodes.erase(pre_nodes);
+          if (potential_stop_nodes_.find(pre_nodes) !=
+              potential_stop_nodes_.end()) {
+            potential_stop_nodes_.erase(pre_nodes);
           }
         }
       } else {  // startup_ops have no precedding nodes
-        VLOG(6) << "Emplace _startup_ops";
-        _startup_ops.emplace(target_node);
+        VLOG(6) << "Emplace startup_ops";
+        startup_ops.emplace(target_node);
       }
     }
-    // Purify potential_startup_nodes again, remove some
+    // Purify potential_startup_nodes_ again, remove some
     // potential startup_nodes that unreach to input target nodes
-    if (!_startup_ops.empty()) {
+    if (!startup_ops.empty()) {
       std::unordered_set<GradNodeBase*> potential_startup_nodes_to_be_erased;
-      for (auto node : potential_startup_nodes) {
-        if (_startup_ops.count(node) == 0) {
+      for (auto node : potential_startup_nodes_) {
+        if (startup_ops.count(node) == 0) {
           VLOG(6) << "Set up potential_startup_nodes_to_be_erased";
           potential_startup_nodes_to_be_erased.emplace(node);
         }
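Illustrative aside, not part of the diff: the renames in this hunk apply the common C++ convention of suffixing private data members with a trailing underscore. A minimal sketch of that convention, using a hypothetical class name:

```cpp
#include <unordered_set>

// Illustrative only: the trailing '_' marks private members, so uses inside
// member functions are easy to tell apart from locals and parameters.
class GeneralGradSketch {  // hypothetical name, not Paddle code
 public:
  void Clear() { potential_startup_nodes_.clear(); }

 private:
  std::unordered_set<int> potential_startup_nodes_;
};
```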
@@ -135,14 +136,14 @@ class GeneralGrad {
     if (!potential_startup_nodes_to_be_erased.empty()) {
       for (auto node : potential_startup_nodes_to_be_erased) {
         VLOG(6) << "Erase nodes in potential_startup_nodes_to_be_erased";
-        potential_startup_nodes.erase(node);
+        potential_startup_nodes_.erase(node);
       }
     }
   }

   // Get Graph Info Betweent input target GradNode and outputs,
-  // record depending_nodes、potential_stop_nodes、potential_startup_nodes
+  // record depending_nodes_、potential_stop_nodes_、potential_startup_nodes_
   void GetGraphInfoBetweenTargets(const std::queue<GradNodeBase*>& init_queue) {
     VLOG(6) << "Runing In GetGraphInfoBetweenTargets";
@@ -164,9 +165,9 @@ class GeneralGrad {
       visited.insert(node);

       // Check node is target_nodes or not, if node is not target_node,
-      // all the next_node will be marked in potential_stop_nodes
+      // all the next_node will be marked in potential_stop_nodes_
       bool is_potential_stop_nodes =
-          input_target_nodes_inputmeta_map.count(node);
+          input_target_nodes_inputmeta_map_.count(node);

       // Find and append next nodes
       const paddle::small_vector<std::vector<GradSlotMeta>,
@@ -186,40 +187,41 @@ class GeneralGrad {
           // all the next_nodes of current node will be inserted to
           // potential_stop_node
           if (is_potential_stop_nodes) {
-            potential_stop_nodes.emplace(next_node);
+            potential_stop_nodes_.emplace(next_node);
           }

           // Update in_degree
-          if (!node_in_degree_map.count(next_node))
+          if (!node_in_degree_map.count(next_node)) {
             node_in_degree_map[next_node] = 0;
+          }
           node_in_degree_map[next_node]++;

           // Record depending relationship
-          (depending_nodes)[next_node].emplace(node);
+          (depending_nodes_)[next_node].emplace(node);
           queue.push(next_node);
         }
       }
     }

     // Update Graph Info, remove some nodes in
-    // potential_stop_nodes、potential_startup_nodes、
+    // potential_stop_nodes_、potential_startup_nodes_、
     UpdateGraphInfo();
   }

   void ModifyReadyQueue(std::queue<GradNodeBase*>* queue) {
     std::queue<GradNodeBase*> tmp_queue;
-    for (auto nodes : potential_startup_nodes) {
+    for (auto nodes : potential_startup_nodes_) {
       tmp_queue.emplace(nodes);
     }
     tmp_queue.swap(*queue);
   }

-  // Set result for input target grad_var when potential_startup_nodes is empty
+  // Set result for input target grad_var when potential_startup_nodes_ is empty
   void SetResultForInputTargetVar(
       const std::unordered_map<GradNodeBase*,
                                std::unique_ptr<GradTensorHolder>>&
           node_input_buffers_dict) {
-    if (potential_startup_nodes.size() == 0) {
-      for (auto input_target_node : *GetInPutTargetNodesInputMetaMap()) {
+    if (potential_startup_nodes_.size() == 0) {
+      for (auto input_target_node : *GetInputTargetNodesInputMetaMap()) {
         // out rank_info of forward op
         auto rank_info = input_target_node.second->OutRankInfo();
         auto iter = node_input_buffers_dict.find(input_target_node.first);
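A side note, not from the commit: UpdateGraphInfo above walks backwards from the input target nodes over depending_nodes_, erasing every predecessor it reaches from the potential stop set and collecting nodes with no predecessors as the real startup nodes. A simplified, self-contained sketch of that reverse traversal, with hypothetical names and int node ids standing in for GradNodeBase*:

```cpp
#include <queue>
#include <unordered_map>
#include <unordered_set>

using Node = int;  // stand-in for GradNodeBase*

void PruneFromTargets(
    const std::unordered_set<Node>& target_nodes,
    const std::unordered_map<Node, std::unordered_set<Node>>& depending_nodes,
    std::unordered_set<Node>* potential_stop_nodes,
    std::unordered_set<Node>* startup_nodes) {
  std::queue<Node> queue;
  std::unordered_set<Node> visited;
  for (Node n : target_nodes) queue.push(n);
  while (!queue.empty()) {
    Node node = queue.front();
    queue.pop();
    if (!visited.insert(node).second) continue;  // skip already-visited nodes
    auto it = depending_nodes.find(node);
    if (it != depending_nodes.end() && !it->second.empty()) {
      for (Node pre : it->second) {
        queue.push(pre);
        potential_stop_nodes->erase(pre);  // lies on a path to a target
      }
    } else {
      startup_nodes->insert(node);  // no predecessors: a genuine startup node
    }
  }
}
```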
@@ -227,7 +229,7 @@ class GeneralGrad {
           auto& target_result =
               (iter->second)->Buffers()[rank_info.first][rank_info.second];
           // save the target result
-          results_map[input_target_node.first] = target_result;
+          results_map_[input_target_node.first] = target_result;
         }
       }
     }
@@ -236,8 +238,8 @@ class GeneralGrad {
   // Set input target grad_var from node_input_buffer by inputmeta
   void SetResultForInputTargetVar(GradTensorHolder input_buffers,
                                   GradNodeBase* node) {
-    auto iter = GetInPutTargetNodesInputMetaMap()->find(node);
-    if (iter != GetInPutTargetNodesInputMetaMap()->end()) {
+    auto iter = GetInputTargetNodesInputMetaMap()->find(node);
+    if (iter != GetInputTargetNodesInputMetaMap()->end()) {
       VLOG(6) << "Get target result by by inputmeta";
       // out rank_info of forward op
       auto rank_info = (iter->second)->OutRankInfo();
@@ -245,7 +247,7 @@ class GeneralGrad {
       auto& target_result =
           input_buffers.Buffers()[rank_info.first][rank_info.second];
       // save the target result
-      results_map[node] = target_result;
+      results_map_[node] = target_result;
     }
   }
@@ -271,8 +273,8 @@ class GeneralGrad {
           "input";
     }
-    auto iter = results_map.find(target_node);
-    if (iter != results_map.end()) {
+    auto iter = results_map_.find(target_node);
+    if (iter != results_map_.end()) {
       // set StopGradient = !create_graph
       AutogradMeta* tensor_auto_grad_meta =
           EagerUtils::autograd_meta(&(iter->second));
@@ -303,12 +305,12 @@ class GeneralGrad {
     GetTargetNodesInfo(no_grad_vars, true /* is_no_grad_vars */);
     // Get inputs's GradNodes and InputMeta Info
     GetTargetNodesInfo(inputs, false /* is_no_grad_vars */);
-    // Purify potential_startup_ops, remove those nodes that are the same as
+    // Purify potentialstartup_ops, remove those nodes that are the same as
     // input_target_nodes
     PurifyPotentialStartUpNodes();
     // Get Graph Info Betweent input target gradnode and outputs
-    // Record the depending_nodes and
-    // potential_stop_nodes、potential_startup_nodes
+    // Record the depending_nodes_ and
+    // potential_stop_nodes_、potential_startup_nodes_
     GetGraphInfoBetweenTargets(*queue);
     // Reset queue. Queue is empty only when
     // 1.input equals to output. 2.input can not reach to output.
@@ -318,34 +320,34 @@ class GeneralGrad {
   }

   bool IsPotentialStopNodes(GradNodeBase* node) {
-    return potential_stop_nodes.count(node);
+    return potential_stop_nodes_.count(node);
   }

   std::unordered_map<GradNodeBase*, AutogradMeta*>*
   GetNoGradVarNodesInputMetaMap() {
-    return &no_grad_var_nodes_inputmeta_map;
+    return &no_grad_var_nodes_inputmeta_map_;
   }

   std::unordered_map<GradNodeBase*, AutogradMeta*>*
-  GetInPutTargetNodesInputMetaMap() {
-    return &input_target_nodes_inputmeta_map;
+  GetInputTargetNodesInputMetaMap() {
+    return &input_target_nodes_inputmeta_map_;
   }

   std::unordered_set<GradNodeBase*>* GetPotentialStopNodes() {
-    return &potential_stop_nodes;
+    return &potential_stop_nodes_;
   }

   std::unordered_set<GradNodeBase*>* GetPotentialStartupNodes() {
-    return &potential_startup_nodes;
+    return &potential_startup_nodes_;
   }

   void Clear() {
-    no_grad_var_nodes_inputmeta_map.clear();
-    input_target_nodes_inputmeta_map.clear();
-    potential_startup_nodes.clear();
-    potential_stop_nodes.clear();
-    depending_nodes.clear();
-    results_map.clear();
+    no_grad_var_nodes_inputmeta_map_.clear();
+    input_target_nodes_inputmeta_map_.clear();
+    potential_startup_nodes_.clear();
+    potential_stop_nodes_.clear();
+    depending_nodes_.clear();
+    results_map_.clear();
     copied_grad_nodes_.clear();
     orig_to_copied_node_mapping_.clear();
   }
@@ -426,18 +428,18 @@ class GeneralGrad {
   static GeneralGrad* general_grad_;
   // no_grad_vars's GradNode and GradNode's InputMeta.
   std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
-      no_grad_var_nodes_inputmeta_map;
+      no_grad_var_nodes_inputmeta_map_;
   // inputs's GradNode and GradNode's InputMeta.
   std::unordered_map<GradNodeBase*, AutogradMeta* /* InputMeta */>
-      input_target_nodes_inputmeta_map;
+      input_target_nodes_inputmeta_map_;
   // Record all the potential startup_nodes, will be changed.
-  std::unordered_set<GradNodeBase*> potential_startup_nodes;
+  std::unordered_set<GradNodeBase*> potential_startup_nodes_;
   // Record all the potential stop nodes, will be changed.
-  std::unordered_set<GradNodeBase*> potential_stop_nodes;
+  std::unordered_set<GradNodeBase*> potential_stop_nodes_;
   std::unordered_map<GradNodeBase* /* next node */,
                      std::unordered_set<GradNodeBase*> /* pre nodes */>
-      depending_nodes;
-  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;
+      depending_nodes_;
+  std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map_;
   std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
   std::unordered_map<GradNodeBase*, std::shared_ptr<GradNodeBase>>
@@ -619,7 +621,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
       // GradTensorHolder will initialize another tensor with same tensortype,
       // datatype and dims but filled with 1.0
       node_input_buffers_dict[grad_node]->CopyValueFromTensor(
-          input_info.first, input_info.second, tensor, true /*fill_one=true*/);
+          input_info.first, input_info.second, tensor, /*fill_one=*/true);
     }
     // Prepare queue, potential startup_nodes
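Illustrative aside, not from the commit: the hunk switches the boolean literal from a trailing comment to the /*param=*/value spelling, which keeps the parameter name attached to the argument and can be verified by tooling such as clang-tidy's bugprone-argument-comment check. A tiny sketch with a hypothetical signature:

```cpp
// Hypothetical signature, for illustration only.
void CopyValue(int slot, int rank, bool fill_one);

void Example() {
  // The name sits right before the literal, so it can be checked against the
  // parameter name in the declaration.
  CopyValue(/*slot=*/0, /*rank=*/1, /*fill_one=*/true);
}
```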
@@ -657,7 +659,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     VLOG(6) << "Running GradNode:" << node->name();

     paddle::platform::RecordEvent node_record_event(
-        std::string((*node).name()) + " grad_node",
+        std::string((*node).name()),
         paddle::platform::TracerEventType::Operator, 1);

     if (queue.size() > 1 && node_in_degree_map[node] != 0) {
@@ -667,14 +669,15 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     queue.pop();

     // Run node: This is where Hook happens
-    PADDLE_ENFORCE(
-        node_input_buffers_dict.count(node),
+    auto node_input_buffer_iter = node_input_buffers_dict.find(node);
+    PADDLE_ENFORCE_NE(
+        node_input_buffer_iter, node_input_buffers_dict.end(),
         paddle::platform::errors::Fatal(
             "Unable to find next node in the GradTensorHolder \n"
             "Trying to run Node without configuring its GradTensorHolder."));

     std::unique_ptr<GradTensorHolder> node_input_buffer =
-        std::move(node_input_buffers_dict[node]);
+        std::move(node_input_buffer_iter->second);

     // Set input target grad_var from node_input_buffer by inputmeta
     if (!inputs.empty() && is_general_grad) {
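Aside, not part of the diff: this hunk replaces a count() check followed by operator[] with a single find(), so the key is hashed once and the same iterator serves the enforce, the move, and the erase in the next hunk. A minimal sketch of the pattern, not Paddle code:

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>

struct Holder {};  // stand-in for GradTensorHolder

std::unique_ptr<Holder> TakeHolder(
    std::unordered_map<int, std::unique_ptr<Holder>>* dict, int node) {
  auto iter = dict->find(node);  // one lookup instead of count() + operator[]
  assert(iter != dict->end() && "node has no configured holder");
  auto holder = std::move(iter->second);
  dict->erase(iter);  // erasing by iterator avoids hashing the key again
  return holder;
}
```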
@@ -715,8 +718,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     }
     // TODO(jiabin): Should we erase it or find a more efficient way.
-    node_input_buffers_dict.erase(node);
+    node_input_buffers_dict.erase(node_input_buffer_iter);

     // Prepare GradTensorHolder for next node
     const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
@@ -736,8 +738,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     }
     auto edge_rank = edge.GetEdgeRankInfo();
     // Since we make edge has as same rank as bwd outputs, we indexing them
-    // with
-    // the same rank(i, j)
+    // with the same rank(i, j)
     auto next_node_shared = edge.GetMutableGradNode();

     // Next node could be nullptr if it is leaf tensor with no
paddle/fluid/eager/grad_node_info.cc
@@ -36,6 +36,31 @@
  **/
 namespace egr {

+static void CheckTensor(const paddle::experimental::Tensor& pre,
+                        const paddle::experimental::Tensor& post) {
+  if (!pre.initialized() && post.initialized()) {
+    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
+        "The tensor in before and after hook are not consistent"));
+  }
+  if (pre.initialized() && post.initialized()) {
+    VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
+            << paddle::framework::DataType2String(post.dtype());
+    PADDLE_ENFORCE_EQ(
+        pre.dtype(), post.dtype(),
+        paddle::platform::errors::PermissionDenied(
+            "The dtype of tensor before(%s) and after(%s) hook are not "
+            "consistent",
+            paddle::framework::DataType2String(pre.dtype()),
+            paddle::framework::DataType2String(post.dtype())));
+    PADDLE_ENFORCE_EQ(pre.place(), post.place(),
+                      paddle::platform::errors::PermissionDenied(
+                          "The place of tensor before(%s) and after(%s) "
+                          "hook are not consistent",
+                          pre.place().DebugString(),
+                          post.place().DebugString()));
+  }
+}
+
 GradNodeBase::GradNodeBase(size_t bwd_in_slot_num, size_t bwd_out_slot_num) {
   VLOG(6) << "Construct GradNodeBase";
   bwd_in_meta_.resize(bwd_in_slot_num);
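Aside, not part of the diff: CheckTensor is added here as a static free function, and the grad_node_info.h hunk further down removes the old inline definition from the header. Marking a file-local helper static gives it internal linkage, so it is visible only inside this translation unit. A minimal sketch under that assumption, with hypothetical names:

```cpp
// illustrative_module.cc (hypothetical file, not part of Paddle)
#include <stdexcept>

// Internal linkage: visible only inside this translation unit, so it cannot
// collide with a definition in another file.
static void CheckPositive(int value) {
  if (value <= 0) throw std::invalid_argument("value must be positive");
}

int Scale(int value, int factor) {
  CheckPositive(factor);
  return value * factor;
}
```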
@@ -271,7 +296,7 @@ void GradNodeBase::SetGradOutMeta(
       // Only Copy Meta
       phi::DenseTensor* dense_tensor =
           static_cast<phi::DenseTensor*>(fwd_in_tensor.impl().get());
-      PADDLE_ENFORCE_NE(dense_tensor->meta().dtype, phi::DataType::UNDEFINED,
+      PADDLE_ENFORCE_NE(dense_tensor->dtype(), phi::DataType::UNDEFINED,
                         paddle::platform::errors::Fatal(
                             "Attempting to copy DenseTensorMeta "
                             "with phi::DataType::UNDEFINED,"
paddle/fluid/eager/grad_node_info.h
@@ -30,32 +30,23 @@ namespace egr {
  * The GradNodeBase will be held in autograd_meta, and it is also a member of
  * Edge, which indicates the edge of backward graph.
  *
- * TODO:(yangzhanlue) GradNodeBase will also in charge of get the correct input
+ * TODO(yangzhanlue): GradNodeBase will also in charge of get the correct input
  * from GradOpDescMaker to GradNodeBase.
  *
- * NOTE:GradNodeBase has a method named run, this method should be overrided by
- * the
- * specific derived class, it will prepare backward inputs and double backward's
- * depends. Then, it will call C++ API of backward kernel functions to finish
- * backward computation.
+ * NOTE: GradNodeBase has a method named run, this method should be overrided by
+ * the specific derived class, it will prepare backward inputs and double
+ * backward's depends. Then, it will call C++ API of backward kernel functions
+ * to finish backward computation.
  *
- * NOTE:GradNodeBase holds its own inputs and Outputs
+ * NOTE: GradNodeBase holds its own inputs and Outputs
  *
  * Edge is defined to descripe depend of backward, an Edge is what linked
- * between two
- * node, it should contain a Node and rank of this Node (this is used to
- * indicate which
- * input of grad this edge belong).
- * */
+ * between two node, it should contain a Node and rank of this Node (this is
+ * used to indicate which input of grad this edge belong).
+ **/
 class AutogradMeta;
 class GradNodeBase;

-/**
- * GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
- * has lots of operators
- * whose backward logic is depends on if it has some specific inputs or outputs.
- * So, we need a meta info
- * to record it's needs.
- * **/
 class Edge {
  public:
   // Default constructor for Edges in order to construct it for AutogradMeta
@@ -64,8 +55,7 @@ class Edge {
   // In real use cases we should create Edge from grad node and input rank which
   // indicate which edge it is.
   // Since we have slot design in operators we will have to locate an edge with
-  // slot
-  // and rank.
+  // slot and rank.
   Edge(const std::shared_ptr<GradNodeBase>& grad_node, size_t in_slot_id,
        size_t in_rank)
       : in_slot_id_(in_slot_id), in_rank_(in_rank), grad_node_(grad_node) {}
@@ -120,6 +110,12 @@ class Edge {
   size_t in_rank_;
   std::shared_ptr<GradNodeBase> grad_node_{nullptr};
 };

+/**
+ * GradSlotMeta is used to Record Forward Tensor info to backward, since paddle
+ * has lots of operators whose backward logic is depends on if it has some
+ * specific inputs or outputs. So, we need a meta info to record it's needs.
+ **/
 class GradSlotMeta {
  public:
   GradSlotMeta() = default;
@@ -171,16 +167,13 @@ class GradNodeBase {
   /**
-   * operator() designed to contian the real backward execution logic, it should
-   * be
-   * overrided by derived class defined for each operator. It accepts a vector
-   * of
-   * Tensor which contains grads input of current operator
+   * operator() designed to contian the real backward execution logic, it should
+   * be overrided by derived class defined for each operator. It accepts a
+   * vector of Tensor which contains grads input of current operator
   *
   * Note: why we need backward inputs and outputs construct as vector of vector
   * of paddle::experimental::Tensor?
   * Since all of paddle op composite in form of {"Slot name ", vector<Var>},
-   * so, vector of vector
-   * is better choice to fit this format.
+   * so, vector of vector is better choice to fit this format.
   * **/
   virtual paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                                kSlotSmallVectorSize>
@@ -294,36 +287,12 @@ class GradNodeBase {
       /* slot id */ size_t, /* rank */ size_t,
       /* hook */ std::shared_ptr<TensorHook>>>
       gradient_hooks_;
-  int64_t next_hook_id_{0};

   // We handle complex to real conversion only if any complex GradIn is involved
   bool need_complex_to_real_ = false;
+  int64_t next_hook_id_{0};
   bool is_tensor_wrappers_cleared_ = false;
 };

-inline void CheckTensor(const paddle::experimental::Tensor& pre,
-                        const paddle::experimental::Tensor& post) {
-  if (!pre.initialized() && post.initialized()) {
-    PADDLE_THROW(paddle::platform::errors::PermissionDenied(
-        "The tensor in before and after hook are not consistent"));
-  }
-  if (pre.initialized() && post.initialized()) {
-    VLOG(4) << paddle::framework::DataType2String(pre.dtype()) << " "
-            << paddle::framework::DataType2String(post.dtype());
-    PADDLE_ENFORCE_EQ(
-        pre.dtype(), post.dtype(),
-        paddle::platform::errors::PermissionDenied(
-            "The dtype of tensor before(%s) and after(%s) hook are not "
-            "consistent",
-            paddle::framework::DataType2String(pre.dtype()),
-            paddle::framework::DataType2String(post.dtype())));
-    PADDLE_ENFORCE_EQ(pre.place(), post.place(),
-                      paddle::platform::errors::PermissionDenied(
-                          "The place of tensor before(%s) and after(%s) "
-                          "hook are not consistent",
-                          pre.place().DebugString(),
-                          post.place().DebugString()));
-  }
-}
 }  // namespace egr
paddle/fluid/eager/tensor_wrapper.h
@@ -88,6 +88,7 @@ class TensorWrapper {
     } else {
       intermidiate_tensor_.set_impl(tensor.impl());
     }
+
     // TODO(jiabin): This may has server performance issue
     intermidiate_tensor_.set_name(tensor.name() + "@Saved");
@@ -118,24 +119,25 @@ class TensorWrapper {
     paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
     std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
-    if (new_grad_node) {
-      VLOG(3) << "Recovered TensorWrapper with GradNode "
-              << new_grad_node->name() << " addr: " << new_grad_node.get();
-    } else {
-      VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
-    }
     auto* intermediate_autograd_meta =
         EagerUtils::unsafe_autograd_meta(intermidiate_tensor_);
     auto p_ab_autograd_meta =
         std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
     if (new_grad_node) {
+      VLOG(3) << "Recovered TensorWrapper with GradNode "
+              << new_grad_node->name() << " addr: " << new_grad_node.get();
       p_ab_autograd_meta->SetGradNode(new_grad_node);
+    } else {
+      VLOG(3) << "Recovered TensorWrapper with Empth GradNode";
     }
     recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
     return recovered_tensor;
   }
 }

+  void clear() { intermidiate_tensor_.reset(); }
+
  private:
   void check_inplace_version() {
     if (no_need_buffer_) {
       VLOG(6) << "There's no need to check inplace_version because "
@@ -170,8 +172,6 @@ class TensorWrapper {
     }
   }

-  void clear() { intermidiate_tensor_.reset(); }
-
  private:
   bool full_reserved_ = false;
   bool no_need_buffer_ = false;