PaddlePaddle / Paddle

Commit 7ba85aca (unverified)
Authored by Chen Weihang on Apr 14, 2021; committed via GitHub on Apr 14, 2021
Add inner register backward hook method for Tensor (#32171)
* add register backward hook method
* add leaf grad accumulated test
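For orientation, here is a minimal Python sketch of the inner API this commit adds, mirroring the new unit test included below. The hook passed to `_register_backward_hook` takes no arguments, returns nothing, and runs once the Tensor's gradient has been fully accumulated; the names `backward_hook` and `counter` are illustrative only.

    import paddle

    counter = {"calls": 0}

    def backward_hook():
        # fires after the gradient of `x` has been fully accumulated
        counter["calls"] += 1

    x = paddle.to_tensor(5., stop_gradient=False)
    x._register_backward_hook(backward_hook)

    y = paddle.pow(x, 4.0)
    y.backward()

    # the new tests assert one hook call per backward pass
    print(counter["calls"])  # expected: 1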
Parent: f3e49c40

Showing 9 changed files with 280 additions and 135 deletions (+280 -135):
paddle/fluid/imperative/basic_engine.cc  +5 -5
paddle/fluid/imperative/gradient_accumulator.cc  +7 -7
paddle/fluid/imperative/hooks.h  +27 -37
paddle/fluid/imperative/layer.h  +12 -10
paddle/fluid/imperative/reducer.cc  +2 -3
paddle/fluid/imperative/tests/test_hooks.cc  +46 -20
paddle/fluid/imperative/variable_wrapper.h  +26 -21
paddle/fluid/pybind/imperative.cc  +60 -28
python/paddle/fluid/tests/unittests/test_tensor_register_hook.py  +95 -4
paddle/fluid/imperative/basic_engine.cc (+5 -5)

@@ -284,15 +284,15 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
   for (const auto& pair : bwd_ins) {
     for (size_t i = 0; i < pair.second.size(); ++i) {
       auto& var = pair.second[i];
-      if (var->HasHook()) {
+      if (var->HasVariableWrapperHook()) {
         if (tmp_ins_ptr == nullptr) {
           tmp_ins_ptr = std::make_shared<NameVarMap<VariableWrapper>>(bwd_ins);
         }
-        VLOG(3) << "Call " << var->GetHooks().size() << " hooks of " << op_type
-                << "'s input `" << pair.first << "`'s var `" << var->Name()
-                << "`.";
+        VLOG(3) << "Call " << var->GetVariableWrapperHooks().size()
+                << " hooks of " << op_type << "'s input `" << pair.first
+                << "`'s var `" << var->Name() << "`.";
         auto tmp_var = var;
-        for (const auto& hook_pair : var->GetHooks()) {
+        for (const auto& hook_pair : var->GetVariableWrapperHooks()) {
           tmp_var = (*hook_pair.second)(tmp_var);
         }
         (*tmp_ins_ptr)[pair.first][i] = tmp_var;
paddle/fluid/imperative/gradient_accumulator.cc (+7 -7)

@@ -467,14 +467,14 @@ void GradientAccumulator::CallGradientHooks() {
       platform::errors::PreconditionNotMet("Leaf Tensor's inner var "
                                            "is not initialized when "
                                            "call gradient hook."));
-  if (var_->HasHook()) {
-    VLOG(3) << "Call " << var_->GetHooks().size()
+  if (var_->HasVariableWrapperHook()) {
+    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
             << " hooks of leaf gradient accumulator's inner var `"
             << var_->Name() << "`.";
     auto tmp_var = inner_var_;
     VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
-            << var_->GetHooks().size();
-    for (const auto& hook_pair : var_->GetHooks()) {
+            << var_->GetVariableWrapperHooks().size();
+    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
       tmp_var = (*hook_pair.second)(tmp_var);
     }
     inner_var_ = tmp_var;

@@ -495,10 +495,10 @@ void GradientAccumulator::CallReduceHooks() {
                         "Only can call reduce hooks after the "
                         "gradient accumulation is completed in "
                         "current batch or across batchs."));
-  if (var_->HasMutableHook()) {
-    for (const auto& hook : var_->GetMutableHooks()) {
+  if (var_->HasVoidHook()) {
+    for (const auto& hook : var_->GetVoidHooks()) {
       VLOG(3) << "call gradient accumulator backward hooks.";
-      (*hook)(var_);
+      (*hook)();
     }
   }
paddle/fluid/imperative/hooks.h (+27 -37)

@@ -23,32 +23,34 @@ namespace imperative {
 class VariableWrapper;

-/** [ Const VariableWrapper Hook: Pre hook functor of OpBase ]
+/** [ VariableWrapper Hook ]
  *
- * @brief This hook functor is executed before the grad OpBase is executed,
- *        taking the input of the current grad OpBase as input, and
- *        executing python hooks (user-defined) or C++ hooks (developer-defined)
- *        to achieve the purpose of custom operations on the interior VarBase
- *        gradient.
+ * @brief This hook functor is executed before the grad OpBase is executed or
+ *        after gradient accumulation completed in current batch.
+ *        1. For interior var, VariableWrapper Hook take the input of the
+ *           current grad OpBase as input.
+ *        2. For leaf var, VariableWrapper Hook take the inner_var_ of
+ *           GradientAccumulator as input.
  *
- * @note  This hook functor will not change the input gradient VarBase.
+ * @note  This hook functor will not change the input gradient VariableWrapper,
+ *        but if you copy the input VariableWrapper and change the value of
+ *        Variable in VariableWrapper, the value of input will also be changed,
+ *        because they shared same PlaceHolder.
  *
  * @note  [ Why need to be OpBase `PreHook`, why not `PostHook`? ]
  *
- *        We expect If set OpBase post hook, when the op executed end, the
+ *        1. We expect If set OpBase post hook, when the op executed end, the
  *        op's output gradient may not be the final state, because it may need
  *        other op's gradient output to accumulated to it. But before op can
  *        be executed, the gradient output must have been accumulated to final
  *        value.
+ *        2. We don't want the hook to change its input Tensor value, so now
+ *           we can't call all hooks in GradAccumulator.
  *
- * @note  [ Why only can be used for interior VarBase? ]
+ * @note  [ Why Leaf gradient is special? ]
  *
  *        Because the leaf VarBase's GradVarBase has no GradOpNode, so leaf
  *        GradVarBase has no next OpBase to executed, so if need to deal with
- *        the leaf GradVarBase, cannot use this hook functor. For this case, we
- *        deal with by other inplace hook method.
+ *        the leaf GradVarBase, we should call hooks after gradient accumulation
+ *        completed.
  */
 class VariableWrapperHook {
  public:

@@ -57,34 +59,22 @@ class VariableWrapperHook {
       const std::shared_ptr<VariableWrapper>& var) = 0;
 };

-/** [ Inplace VariableWrapper Hook: Post hook functor of GradAccumulator ]
- *
- * @brief This hook functor is the Hook that operates on the current
- *        gradient after the GradientAccumulator has accumulated the gradient.
- *        Leaf GradVarBase has no next OpBase, if we want to register hook
- *        for it, we also need to wait until the leaf GradVarBase accumulation
- *        is completed, so we can add post hook to GradientAccumulator.
- *
- * @note  This hook functor will change the grad VarBase value.
- *
- * @note  Only allow leaf VarBase hold call this hook functor.
- */
-class InplaceVariableWrapperHook {
- public:
-  virtual ~InplaceVariableWrapperHook() = default;
-  virtual void operator()(VariableWrapper* var) = 0;
-};
-
-class LambdaInplaceVariableWrapperHook : public InplaceVariableWrapperHook {
+class CppVariableWrapperHook : public VariableWrapperHook {
  public:
-  explicit LambdaInplaceVariableWrapperHook(
-      std::function<void(VariableWrapper*)>&& fn)
+  explicit CppVariableWrapperHook(
+      std::function<std::shared_ptr<VariableWrapper>(
+          const std::shared_ptr<VariableWrapper>&)>&& fn)
       : fn_(std::move(fn)) {}

-  void operator()(VariableWrapper* var) override { fn_(var); }
+  std::shared_ptr<VariableWrapper> operator()(
+      const std::shared_ptr<VariableWrapper>& var) override {
+    return fn_(var);
+  }

  private:
-  std::function<void(VariableWrapper*)> fn_;
+  std::function<std::shared_ptr<VariableWrapper>(
+      const std::shared_ptr<VariableWrapper>&)>
+      fn_;
 };

 }  // namespace imperative
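The comment above states the VariableWrapperHook contract: the hook receives the current gradient and returns the value to use in its place, without mutating its input; for a leaf var it now runs on the GradientAccumulator's inner_var_ once accumulation in the current batch finishes. At the Python level this is the `Tensor.register_hook` path exercised in test_tensor_register_hook.py; a small sketch, where the expected gradient values follow the same dot-product math as that test:

    import numpy as np
    import paddle

    x = paddle.to_tensor([0., 1., 2., 4.])
    y = paddle.to_tensor([4., 5., 6., 7.])
    x.stop_gradient = False
    y.stop_gradient = False

    # the hook gets the accumulated gradient and returns a replacement value;
    # the original gradient is not modified in place
    helper = x.register_hook(lambda grad: grad * 2)

    o = x.matmul(y)  # dot product, scalar output
    o.backward()

    # d(o)/dx == y, then doubled by the hook
    print(np.array_equal(x.grad, np.array([4., 5., 6., 7.]) * 2))  # expected: True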
paddle/fluid/imperative/layer.h (+12 -10)

@@ -226,23 +226,25 @@ class VarBase {
   void BumpInplaceVersion();

   /* Hook related method: now only used for GradVarBase */
-  bool HasHook() const { return var_->HasHook(); }
+  bool HasVariableWrapperHook() const {
+    return var_->HasVariableWrapperHook();
+  }

-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    return var_->AddHook(
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    return var_->AddVariableWrapperHook(
         std::forward<std::shared_ptr<VariableWrapperHook>>(hook));
   }

-  bool RemoveHook(const int64_t& hook_id) { return var_->RemoveHook(hook_id); }
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    return var_->RemoveVariableWrapperHook(hook_id);
+  }

-  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
-      const {
-    return var_->GetHooks();
+  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
+  GetVariableWrapperHooks() const {
+    return var_->GetVariableWrapperHooks();
   }

-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    var_->AddMutableHook(
-        std::forward<std::shared_ptr<InplaceVariableWrapperHook>>(hook));
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    var_->AddVoidHook(
+        std::forward<std::shared_ptr<std::function<void()>>>(hook));
   }

  private:
paddle/fluid/imperative/reducer.cc (+2 -3)

@@ -310,9 +310,8 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
   for (size_t global_var_index = 0; global_var_index < vars_.size();
        ++global_var_index) {
     auto var = vars_[global_var_index];
-    var->GradVarBase()->AddMutableHook(
-        std::make_shared<LambdaInplaceVariableWrapperHook>([=](
-            VariableWrapper* grad) { this->AddDistHook(global_var_index); }));
+    var->GradVarBase()->AddVoidHook(std::make_shared<std::function<void()>>(
+        [=]() { this->AddDistHook(global_var_index); }));
     var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index;
   }
paddle/fluid/imperative/tests/test_hooks.cc (+46 -20)

@@ -37,6 +37,30 @@ namespace imperative {
 using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
 using var_pair = std::pair<std::string, vb_vector>;

+std::shared_ptr<imperative::VariableWrapper> DoubleHook(
+    const std::shared_ptr<imperative::VariableWrapper>& var) {
+  // 1. create out var
+  auto out_var = std::make_shared<imperative::VariableWrapper>(var->Name());
+  out_var->SetType(var->Type());
+  out_var->SetDataType(var->DataType());
+  out_var->SetForwardDataType(var->ForwardDataType());
+  out_var->InnerSetOverridedStopGradient(var->InnerOverridedStopGradient());
+
+  // 2. get input and output var's tensor
+  auto* out_tensor = out_var->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto& tensor = var->Var().Get<framework::LoDTensor>();
+  out_tensor->Resize(tensor.dims());
+
+  // 3. double calc
+  auto* data = tensor.data<float>();
+  auto* out_data = out_tensor->mutable_data<float>(platform::CPUPlace());
+  for (int64_t i = 0; i < out_tensor->numel(); ++i) {
+    out_data[i] = data[i] * 2.0;
+  }
+
+  return out_var;
+}
+
 TEST(TestHooks, TestGradVarLeafBackwardHook) {
   // 1. prepare
   Tracer tracer;

@@ -73,16 +97,14 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   framework::AttributeMap mul_attr_map;
   mul_attr_map["use_mkldnn"] = false;

-  // add GradAccumulatorPostHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 10; }));

   // 2. forward
   tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);

@@ -98,12 +120,15 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   engine.Init(tensors, grad_tensors);
   engine.Execute();

+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 8.0);
   }

+  // verify Void hook result
+  ASSERT_EQ(hook_value, 10);
+
   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,

@@ -152,16 +177,14 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   memory::Copy(place, mutable_z, place, src_data.data(),
                sizeof(float) * src_data.size());

-  // add ReduceBackwardHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 100; }));

   // 2. forward
   var_pair x_pair = var_pair("X", vb_vector(1, x));

@@ -199,12 +222,15 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   engine.Init(tensors, grad_tensors);
   engine.Execute();

+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 16.0);
   }

+  // verify Void hook result
+  ASSERT_EQ(hook_value, 100);
+
   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
paddle/fluid/imperative/variable_wrapper.h (+26 -21)

@@ -220,35 +220,35 @@ class VariableWrapper {
   }

   /* Hook related methods */
-  bool HasHook() const { return !hooks_.empty(); }
-
-  bool HasMutableHook() const { return !mutable_hooks_.empty(); }
+  bool HasVariableWrapperHook() const { return !var_hooks_.empty(); }

-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    hooks_.emplace(next_hook_id_, std::move(hook));
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    var_hooks_.emplace(next_hook_id_, std::move(hook));
     return next_hook_id_++;
   }

-  bool RemoveHook(const int64_t& hook_id) {
-    auto remove_cnt = hooks_.erase(hook_id);
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    auto remove_cnt = var_hooks_.erase(hook_id);
     if (remove_cnt == 0) {
       return false;
     }
     return true;
   }

   const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
-  GetHooks() const {
-    return hooks_;
+  GetVariableWrapperHooks() const {
+    return var_hooks_;
   }

-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    mutable_hooks_.emplace_back(std::move(hook));
+  bool HasVoidHook() const { return !void_hooks_.empty(); }
+
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    void_hooks_.emplace_back(std::move(hook));
   }

-  const std::vector<std::shared_ptr<InplaceVariableWrapperHook>>&
-  GetMutableHooks() const {
-    return mutable_hooks_;
+  const std::vector<std::shared_ptr<std::function<void()>>>&
+  GetVoidHooks() const {
+    return void_hooks_;
   }

  private:

@@ -319,14 +319,19 @@ class VariableWrapper {
   // isn't need
   bool is_empty_{false};

-  // NOTE(chenweihang): only grad var can hold hooks now
+  // NOTE(chenweihang): only grad var will hold hooks now
   int64_t next_hook_id_{0};
-  // Hooks used to register hook for grad var, support adding and removing,
+  // [ Hooks with VariableWrapper as input and output ]
+  // NOTE: Now registered for grad var, support adding and removing,
   // key is the accumulated int64_t value
-  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> hooks_;
-  // Hooks executed after the execution of the entire backward process is over,
-  // currently only supported for reducing in distributed training
-  std::vector<std::shared_ptr<InplaceVariableWrapperHook>> mutable_hooks_;
+  // NOTE: Var hook need to support removing, so need hook id
+  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> var_hooks_;
+  // [ Hooks without input and output ]
+  // NOTE: Now registered after the execution of the entire backward
+  // process is over, currently only used for reducing in distributed
+  // training
+  // NOTE: Now no need to support remove void hook
+  std::vector<std::shared_ptr<std::function<void()>>> void_hooks_;
 };

 }  // namespace imperative
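The reason var_hooks_ is a map keyed by next_hook_id_ is that an individual gradient hook must be removable later, while void_hooks_ is an append-only vector with no removal. At the Python level the id is hidden inside the helper object returned by `register_hook`, which the existing tests already rely on; a brief sketch, where the expected gradient values are my own arithmetic for this toy case:

    import paddle

    x = paddle.to_tensor([1., 2., 3.])
    x.stop_gradient = False

    # each registration stores the hook under a fresh id, so hooks can be
    # removed independently and the remaining ones still run in id order
    h1 = x.register_hook(lambda grad: grad * 2)
    h2 = x.register_hook(lambda grad: grad + 1)

    h1.remove()  # drop only the doubling hook

    paddle.sum(x).backward()

    # grad of sum() is all ones; only the "+ 1" hook remains
    print(x.grad)  # expected: [2., 2., 2.]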
paddle/fluid/pybind/imperative.cc (+60 -28)

@@ -1069,20 +1069,58 @@ void BindImperative(py::module *m_ptr) {
       .def("_register_grad_hook",
            [](imperative::VarBase &self, const py::handle &hook) {
              PADDLE_ENFORCE_EQ(
-                 self.HasGradVar(), true,
+                 !self.OverridedStopGradient() && self.HasGradVar(), true,
                  platform::errors::InvalidArgument(
-                     "Cannot register hook on a tensor without gradient."));
-             return self.GradVarBase()->AddHook(
+                     "Cannot register gradient hook on a Tensor that stop "
+                     "gradient or without gradient."));
+             return self.GradVarBase()->AddVariableWrapperHook(
                  std::make_shared<PyVariableWrapperHook>(hook.ptr()));
            })
       .def("_remove_grad_hook",
            [](imperative::VarBase &self, int64_t hook_id) {
              PADDLE_ENFORCE_EQ(
-                 self.HasGradVar(), true,
+                 !self.OverridedStopGradient() && self.HasGradVar(), true,
                  platform::errors::InvalidArgument(
-                     "Cannot remove hook on a tensor without gradient."));
-             return self.GradVarBase()->RemoveHook(hook_id);
+                     "Cannot remove gradient hook on a Tensor that stop "
+                     "gradient or without gradient."));
+             return self.GradVarBase()->RemoveVariableWrapperHook(hook_id);
            })
+      .def("_register_backward_hook",
+           [](imperative::VarBase &self, const py::handle &hook) {
+             PADDLE_ENFORCE_EQ(
+                 self.IsLeaf(), true,
+                 platform::errors::InvalidArgument(
+                     "Only can register backward hook for leaf Tensor."));
+             PADDLE_ENFORCE_EQ(
+                 !self.OverridedStopGradient() && self.HasGradVar(), true,
+                 platform::errors::InvalidArgument(
+                     "Cannot register backward hook on a Tensor that stop "
+                     "gradient or without gradient."));
+             auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
+             self.GradVarBase()->AddVoidHook(
+                 std::make_shared<std::function<void()>>(py_func));
+           },
+           R"DOC(
+             Registers a backward hook for current Tensor.
+
+             This hook will be called every time the gradient of current Tensor has been fully calculated.
+
+             There are two differences with `_register_grad_hook`:
+             1. This backward hook will be executed after the gradient accumulation completed across batchs,
+                but the hook registered by `_register_grad_hook` will be executed the gradient accumulation
+                completed in current batch.
+             2. This backward hook function should have the following signature:
+
+                  hook() -> None
+
+                It requires no input and no return value.
+
+             Args:
+                 hook(function): A backward hook to be registered for Tensor.gradient
+
+             Returns:
+                 None
+           )DOC")
       .def("cpu",
            [](const std::shared_ptr<imperative::VarBase> &self) {
              if (platform::is_cpu_place(self->Place())) {

@@ -1301,21 +1339,15 @@ void BindImperative(py::module *m_ptr) {
            &imperative::VarBase::SetOverridedStopGradient)
       .def_property("persistable", &imperative::VarBase::Persistable,
                     &imperative::VarBase::SetPersistable)
-      .def_property_readonly("shape",
-                             [](imperative::VarBase &self) {
-                               if (self.Var().IsType<framework::LoDTensor>()) {
-                                 return framework::vectorize<int>(
-                                     self.Var()
-                                         .Get<framework::LoDTensor>()
-                                         .dims());
-                               } else if (self.Var()
-                                              .IsType<
-                                                  framework::SelectedRows>()) {
-                                 return framework::vectorize<int>(
-                                     self.Var()
-                                         .Get<framework::SelectedRows>()
-                                         .value()
-                                         .dims());
+      .def_property_readonly(
+          "shape",
+          [](imperative::VarBase &self) {
+            if (self.Var().IsType<framework::LoDTensor>()) {
+              return framework::vectorize<int>(
+                  self.Var().Get<framework::LoDTensor>().dims());
+            } else if (self.Var().IsType<framework::SelectedRows>()) {
+              return framework::vectorize<int>(
+                  self.Var().Get<framework::SelectedRows>().value().dims());
             } else {
               VLOG(2) << "It is meaningless to get shape of "
                          "variable type "
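Besides the docstring above, the binding enforces that `_register_backward_hook` is only valid on a leaf Tensor that still produces a gradient; both violations surface as ValueError on the Python side, which is exactly what the new unit tests below check. A short sketch of those restrictions (the hook name is illustrative):

    import paddle

    def noop_hook():
        pass

    # leaf tensor that requires grad: registration is accepted
    x = paddle.to_tensor(5., stop_gradient=False)
    x._register_backward_hook(noop_hook)

    # interior (non-leaf) tensor: rejected
    y = paddle.pow(x, 4.0)
    try:
        y._register_backward_hook(noop_hook)
    except ValueError:
        print("only leaf Tensors can register a backward hook")

    # leaf tensor with stop_gradient=True: also rejected
    z = paddle.to_tensor(5.)
    try:
        z._register_backward_hook(noop_hook)
    except ValueError:
        print("the Tensor must not stop gradient and must have a grad var")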
python/paddle/fluid/tests/unittests/test_tensor_register_hook.py (+95 -4)

@@ -178,8 +178,9 @@ class TestTensorRegisterHook(unittest.TestCase):
         # register hook and removed
         run_double_hook_for_leaf_var(lambda grad: grad * 2, removed=True)

-    def test_hook_for_accumulated_grad(self):
-        def run_double_hook_for_accumulated_grad(double_hook, removed=False):
+    def test_hook_for_accumulated_grad_interior_var(self):
+        def run_double_hook_for_accumulated_grad_interior_var(double_hook,
+                                                              removed=False):
             for device in self.devices:
                 paddle.set_device(device)

@@ -227,9 +228,50 @@ class TestTensorRegisterHook(unittest.TestCase):
                         if not removed else base_grad))

         # register hook
-        run_double_hook_for_accumulated_grad(lambda grad: grad * 2)
+        run_double_hook_for_accumulated_grad_interior_var(lambda grad: grad * 2)
         # register hook and removed
-        run_double_hook_for_accumulated_grad(
+        run_double_hook_for_accumulated_grad_interior_var(
             lambda grad: grad * 2, removed=True)

+    def test_hook_for_accumulated_grad_leaf_var(self):
+        def run_double_hook_for_accumulated_grad_leaf_var(double_hook,
+                                                          removed=False):
+            for device in self.devices:
+                paddle.set_device(device)
+
+                x = paddle.to_tensor([0., 1., 2., 4.])
+                x.stop_gradient = False
+
+                helper = x.register_hook(double_hook)
+
+                y = paddle.to_tensor([4., 5., 6., 7.])
+                z = paddle.to_tensor([1., 2., 3., 4.])
+                y.stop_gradient = False
+                z.stop_gradient = False
+
+                o1 = x + y
+                o2 = x + z
+                o1.stop_gradient = False
+                o2.stop_gradient = False
+
+                o = o1.matmul(o2)
+
+                # remove hook before backward
+                if removed:
+                    helper.remove()
+
+                o.backward()
+
+                base_grad = np.array([5., 9., 13., 19.])
+                # x.grad is changed by x.hook
+                self.assertTrue(
+                    np.array_equal(x.grad, base_grad * 2
+                                   if not removed else base_grad))
+
+        # register hook
+        run_double_hook_for_accumulated_grad_leaf_var(lambda grad: grad * 2)
+        # register hook and removed
+        run_double_hook_for_accumulated_grad_leaf_var(
+            lambda grad: grad * 2, removed=True)
+
     def test_hook_in_model(self):

@@ -409,5 +451,54 @@ class TestTensorRegisterHook(unittest.TestCase):
             x.register_hook(lambda grad: grad * 2)


+HOOK_INIT_VALUE = 10
+HOOK_IS_CALLED = False
+
+
+def global_void_hook():
+    global HOOK_INIT_VALUE
+    global HOOK_IS_CALLED
+    HOOK_INIT_VALUE *= 2
+    HOOK_IS_CALLED = True
+
+
+class TestTensorRegisterBackwardHook(unittest.TestCase):
+    def setUp(self):
+        self.devices = ["cpu"]
+        if paddle.is_compiled_with_cuda():
+            self.devices.append("gpu")
+
+    def test_register_backward_hook(self):
+        global HOOK_INIT_VALUE
+        global HOOK_IS_CALLED
+        for device in self.devices:
+            x = paddle.to_tensor(5., stop_gradient=False)
+            x._register_backward_hook(global_void_hook)
+            for i in range(5):
+                y = paddle.pow(x, 4.0)
+                y.backward()
+
+            self.assertEqual(HOOK_INIT_VALUE, 320)
+            self.assertTrue(HOOK_IS_CALLED)
+
+            # reset initial value
+            HOOK_INIT_VALUE = 10
+            HOOK_IS_CALLED = False
+
+    def test_register_backward_hook_for_interior_var(self):
+        x = paddle.to_tensor(5., stop_gradient=False)
+        y = paddle.pow(x, 4.0)
+
+        with self.assertRaises(ValueError):
+            y._register_backward_hook(global_void_hook)
+
+    def test_register_backward_hook_for_var_without_gradient(self):
+        x = paddle.to_tensor(5.)
+        y = paddle.pow(x, 4.0)
+
+        with self.assertRaises(ValueError):
+            x._register_backward_hook(global_void_hook)
+
+
 if __name__ == '__main__':
     unittest.main()