Unverified commit 7ba85aca
Authored on Apr 14, 2021 by Chen Weihang; committed via GitHub on Apr 14, 2021
Add inner register backward hook method for Tensor (#32171)

* add register backward hook method
* add leaf grad accumulated test
Parent: f3e49c40
Showing 9 changed files with 280 additions and 135 deletions
paddle/fluid/imperative/basic_engine.cc                               +5   -5
paddle/fluid/imperative/gradient_accumulator.cc                       +7   -7
paddle/fluid/imperative/hooks.h                                       +27  -37
paddle/fluid/imperative/layer.h                                       +12  -10
paddle/fluid/imperative/reducer.cc                                    +2   -3
paddle/fluid/imperative/tests/test_hooks.cc                           +46  -20
paddle/fluid/imperative/variable_wrapper.h                            +26  -21
paddle/fluid/pybind/imperative.cc                                     +60  -28
python/paddle/fluid/tests/unittests/test_tensor_register_hook.py      +95  -4
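The user-facing entry point added by this commit is the private method `Tensor._register_backward_hook`, exercised by the new `TestTensorRegisterBackwardHook` unit test at the end of this diff. A minimal usage sketch adapted from that test (the counter name is illustrative, not part of the commit):

    import paddle

    HOOK_CALL_COUNT = 0  # illustrative counter, not part of the commit

    def backward_hook():
        # Takes no arguments and returns nothing; it only observes that
        # gradient accumulation for the leaf Tensor has finished.
        global HOOK_CALL_COUNT
        HOOK_CALL_COUNT += 1

    x = paddle.to_tensor(5., stop_gradient=False)   # leaf Tensor that requires grad
    x._register_backward_hook(backward_hook)

    y = paddle.pow(x, 4.0)
    y.backward()          # the hook fires once per backward pass

    assert HOOK_CALL_COUNT == 1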
paddle/fluid/imperative/basic_engine.cc

@@ -284,15 +284,15 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
   for (const auto& pair : bwd_ins) {
     for (size_t i = 0; i < pair.second.size(); ++i) {
       auto& var = pair.second[i];
-      if (var->HasHook()) {
+      if (var->HasVariableWrapperHook()) {
         if (tmp_ins_ptr == nullptr) {
           tmp_ins_ptr = std::make_shared<NameVarMap<VariableWrapper>>(bwd_ins);
         }
-        VLOG(3) << "Call " << var->GetHooks().size() << " hooks of " << op_type
-                << "'s input `" << pair.first << "`'s var `" << var->Name()
-                << "`.";
+        VLOG(3) << "Call " << var->GetVariableWrapperHooks().size()
+                << " hooks of " << op_type << "'s input `" << pair.first
+                << "`'s var `" << var->Name() << "`.";
         auto tmp_var = var;
-        for (const auto& hook_pair : var->GetHooks()) {
+        for (const auto& hook_pair : var->GetVariableWrapperHooks()) {
           tmp_var = (*hook_pair.second)(tmp_var);
         }
         (*tmp_ins_ptr)[pair.first][i] = tmp_var;
paddle/fluid/imperative/gradient_accumulator.cc

@@ -467,14 +467,14 @@ void GradientAccumulator::CallGradientHooks() {
                         platform::errors::PreconditionNotMet(
                             "Leaf Tensor's inner var "
                             "is not initialized when "
                             "call gradient hook."));
-  if (var_->HasHook()) {
-    VLOG(3) << "Call " << var_->GetHooks().size()
+  if (var_->HasVariableWrapperHook()) {
+    VLOG(3) << "Call " << var_->GetVariableWrapperHooks().size()
            << " hooks of leaf gradient accumulator's inner var `"
            << var_->Name() << "`.";
     auto tmp_var = inner_var_;
     VLOG(3) << "Input var " << var_->Name() << "'s hook size - "
-            << var_->GetHooks().size();
-    for (const auto& hook_pair : var_->GetHooks()) {
+            << var_->GetVariableWrapperHooks().size();
+    for (const auto& hook_pair : var_->GetVariableWrapperHooks()) {
       tmp_var = (*hook_pair.second)(tmp_var);
     }
     inner_var_ = tmp_var;

@@ -495,10 +495,10 @@ void GradientAccumulator::CallReduceHooks() {
                         "Only can call reduce hooks after the "
                         "gradient accumulation is completed in "
                         "current batch or across batchs."));
-  if (var_->HasMutableHook()) {
-    for (const auto& hook : var_->GetMutableHooks()) {
+  if (var_->HasVoidHook()) {
+    for (const auto& hook : var_->GetVoidHooks()) {
       VLOG(3) << "call gradient accumulator backward hooks.";
-      (*hook)(var_);
+      (*hook)();
     }
   }
 }
paddle/fluid/imperative/hooks.h

@@ -23,32 +23,34 @@ namespace imperative {
 class VariableWrapper;

-/** [ Const VariableWrapper Hook: Pre hook functor of OpBase ]
+/** [ VariableWrapper Hook ]
  *
- * @brief This hook functor is executed before the grad OpBase is executed,
- *        taking the input of the current grad OpBase as input, and
- *        executing python hooks (user-defined) or C++ hooks (developer-defined)
- *        to achieve the purpose of custom operations on the interior VarBase
- *        gradient.
+ * @brief This hook functor is executed before the grad OpBase is executed or
+ *        after gradient accumulation completed in current batch.
+ *        1. For interior var, VariableWrapper Hook take the input of the
+ *           current grad OpBase as input.
+ *        2. For leaf var, VariableWrapper Hook take the inner_var_ of
+ *           GradientAccumulator as input.
  *
- * @note This hook functor will not change the input gradient VarBase.
+ * @note This hook functor will not change the input gradient VariableWrapper,
+ *       but if you copy the input VariableWrapper and change the value of
+ *       Variable in VariableWrapper, the value of input will also be changed,
+ *       because they shared same PlaceHolder.
  *
  * @note [ Why need to be OpBase `PreHook`, why not `PostHook`? ]
  *
- *       1. We expect If set OpBase post hook, when the op executed end, the
+ *       We expect If set OpBase post hook, when the op executed end, the
  *       op's output gradient may not be the final state, because it may need
  *       other op's gradient output to accumulated to it. But before op can
  *       be executed, the gradient output must have been accumulated to final
  *       value.
- *       2. We don’t want the hook to change its input Tensor value, so now
- *       we can't call all hooks in GradAccumulator.
  *
- * @note [ Why only can be used for interior VarBase? ]
+ * @note [ Why Leaf gradient is special? ]
  *
  *       Because the leaf VarBase's GradVarBase has no GradOpNode, so leaf
  *       GradVarBase has no next OpBase to executed, so if need to deal with
- *       the leaf GradVarBase, cannot use this hook functor. For this case, we
- *       deal with by other inplace hook method.
+ *       the leaf GradVarBase, we should call hooks after gradient accumulation
+ *       completed.
  */
 class VariableWrapperHook {
  public:

@@ -57,34 +59,22 @@ class VariableWrapperHook {
       const std::shared_ptr<VariableWrapper>& var) = 0;
 };

-/** [ Inplace VariableWrapper Hook: Post hook functor of GradAccumulator ]
- *
- * @brief This hook functor is the Hook that operates on the current
- *        gradient after the GradientAccumulator has accumulated the gradient.
- *        Leaf GradVarBase has no next OpBase, if we want to register hook
- *        for it, we also need to wait until the leaf GradVarBase accumulation
- *        is completed, so we can add post hook to GradientAccumulator.
- *
- * @note This hook functor will change the grad VarBase value.
- *
- * @note Only allow leaf VarBase hold call this hook functor.
- */
-class InplaceVariableWrapperHook {
- public:
-  virtual ~InplaceVariableWrapperHook() = default;
-  virtual void operator()(VariableWrapper* var) = 0;
-};
-
-class LambdaInplaceVariableWrapperHook : public InplaceVariableWrapperHook {
+class CppVariableWrapperHook : public VariableWrapperHook {
  public:
-  explicit LambdaInplaceVariableWrapperHook(
-      std::function<void(VariableWrapper*)>&& fn)
+  explicit CppVariableWrapperHook(
+      std::function<std::shared_ptr<VariableWrapper>(
+          const std::shared_ptr<VariableWrapper>&)>&& fn)
       : fn_(std::move(fn)) {}

-  void operator()(VariableWrapper* var) override { fn_(var); }
+  std::shared_ptr<VariableWrapper> operator()(
+      const std::shared_ptr<VariableWrapper>& var) override {
+    return fn_(var);
+  }

  private:
-  std::function<void(VariableWrapper*)> fn_;
+  std::function<std::shared_ptr<VariableWrapper>(
+      const std::shared_ptr<VariableWrapper>&)>
+      fn_;
 };

 }  // namespace imperative
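The rewritten header comment above describes the unified VariableWrapperHook contract: for an interior var the hook takes the input of the current grad OpBase, for a leaf var it takes the inner_var_ of the GradientAccumulator once accumulation in the current batch completes, and in both cases it returns a (possibly new) VariableWrapper instead of mutating its input. On the Python side this corresponds to `Tensor.register_hook`; a small sketch condensed from the leaf-var test added to test_tensor_register_hook.py in this commit:

    import numpy as np
    import paddle

    x = paddle.to_tensor([0., 1., 2., 4.])
    x.stop_gradient = False              # leaf tensor that requires grad

    # grad -> grad hook: receives the accumulated gradient, returns a new one
    x.register_hook(lambda grad: grad * 2)

    y = paddle.to_tensor([4., 5., 6., 7.])
    z = paddle.to_tensor([1., 2., 3., 4.])
    y.stop_gradient = False
    z.stop_gradient = False

    o = (x + y).matmul(x + z)
    o.backward()

    # Without the hook x.grad would be [5., 9., 13., 19.]; the hook doubles it.
    assert np.array_equal(x.grad, np.array([5., 9., 13., 19.]) * 2)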
paddle/fluid/imperative/layer.h

@@ -226,23 +226,25 @@ class VarBase {
   void BumpInplaceVersion();

   /* Hook related method: now only used for GradVarBase */
-  bool HasHook() const { return var_->HasHook(); }
+  bool HasVariableWrapperHook() const { return var_->HasVariableWrapperHook(); }

-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    return var_->AddHook(
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    return var_->AddVariableWrapperHook(
         std::forward<std::shared_ptr<VariableWrapperHook>>(hook));
   }

-  bool RemoveHook(const int64_t& hook_id) { return var_->RemoveHook(hook_id); }
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    return var_->RemoveVariableWrapperHook(hook_id);
+  }

-  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
-      const {
-    return var_->GetHooks();
+  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
+  GetVariableWrapperHooks() const {
+    return var_->GetVariableWrapperHooks();
   }

-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    var_->AddMutableHook(
-        std::forward<std::shared_ptr<InplaceVariableWrapperHook>>(hook));
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    var_->AddVoidHook(
+        std::forward<std::shared_ptr<std::function<void()>>>(hook));
   }

  private:
paddle/fluid/imperative/reducer.cc

@@ -310,9 +310,8 @@ Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
   for (size_t global_var_index = 0; global_var_index < vars_.size();
        ++global_var_index) {
     auto var = vars_[global_var_index];
-    var->GradVarBase()->AddMutableHook(
-        std::make_shared<LambdaInplaceVariableWrapperHook>([=](
-            VariableWrapper *grad) { this->AddDistHook(global_var_index); }));
+    var->GradVarBase()->AddVoidHook(std::make_shared<std::function<void()>>(
+        [=]() { this->AddDistHook(global_var_index); }));
     var_index_map_[var->GradVarBase()->SharedVar().get()] = global_var_index;
   }
paddle/fluid/imperative/tests/test_hooks.cc

@@ -37,6 +37,30 @@ namespace imperative {
 using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
 using var_pair = std::pair<std::string, vb_vector>;

+std::shared_ptr<imperative::VariableWrapper> DoubleHook(
+    const std::shared_ptr<imperative::VariableWrapper>& var) {
+  // 1. create out var
+  auto out_var = std::make_shared<imperative::VariableWrapper>(var->Name());
+  out_var->SetType(var->Type());
+  out_var->SetDataType(var->DataType());
+  out_var->SetForwardDataType(var->ForwardDataType());
+  out_var->InnerSetOverridedStopGradient(var->InnerOverridedStopGradient());
+
+  // 2. get input and output var's tensor
+  auto* out_tensor = out_var->MutableVar()->GetMutable<framework::LoDTensor>();
+  auto& tensor = var->Var().Get<framework::LoDTensor>();
+  out_tensor->Resize(tensor.dims());
+
+  // 3. double calc
+  auto* data = tensor.data<float>();
+  auto* out_data = out_tensor->mutable_data<float>(platform::CPUPlace());
+  for (int64_t i = 0; i < out_tensor->numel(); ++i) {
+    out_data[i] = data[i] * 2.0;
+  }
+
+  return out_var;
+}
+
 TEST(TestHooks, TestGradVarLeafBackwardHook) {
   // 1. prepare
   Tracer tracer;

@@ -73,16 +97,14 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   framework::AttributeMap mul_attr_map;
   mul_attr_map["use_mkldnn"] = false;

-  // add GradAccumulatorPostHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 10; }));

   // 2. forward
   tracer.TraceOp("mul", ins, outs, mul_attr_map, place, true);

@@ -98,12 +120,15 @@ TEST(TestHooks, TestGradVarLeafBackwardHook) {
   engine.Init(tensors, grad_tensors);
   engine.Execute();

+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 8.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 10);

   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,

@@ -152,16 +177,14 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   memory::Copy(place, mutable_z, place, src_data.data(),
                sizeof(float) * src_data.size());

-  // add ReduceBackwardHook
-  x->GradVarBase()->AddMutableHook(
-      std::make_shared<LambdaInplaceVariableWrapperHook>(
-          [=](VariableWrapper* grad) {
-            auto* grad_tensor =
-                grad->MutableVar()->GetMutable<framework::LoDTensor>();
-            for (int i = 0; i < grad_tensor->numel(); ++i) {
-              grad_tensor->mutable_data<float>(place)[i] *= 2.0;
-            }
-          }));
+  // add VariableWrapper hook
+  x->GradVarBase()->AddVariableWrapperHook(
+      std::make_shared<imperative::CppVariableWrapperHook>(DoubleHook));
+
+  // add Void hook
+  int64_t hook_value = 0;
+  x->GradVarBase()->AddVoidHook(
+      std::make_shared<std::function<void()>>([&]() { hook_value = 100; }));

   // 2. forward
   var_pair x_pair = var_pair("X", vb_vector(1, x));

@@ -199,12 +222,15 @@ void GradVarLeafBackwardHookWithGradAccmulatedTest() {
   engine.Init(tensors, grad_tensors);
   engine.Execute();

+  // verify VariableWrapper hook result
   framework::LoDTensor x_grad;
   framework::TensorCopySync(x->GradVar().Get<framework::LoDTensor>(), place,
                             &x_grad);
   for (int i = 0; i < x_grad.numel(); ++i) {
     ASSERT_EQ(x_grad.data<float>()[i], 16.0);
   }
+  // verify Void hook result
+  ASSERT_EQ(hook_value, 100);

   framework::LoDTensor y_grad;
   framework::TensorCopySync(y->GradVar().Get<framework::LoDTensor>(), place,
paddle/fluid/imperative/variable_wrapper.h

@@ -220,35 +220,35 @@ class VariableWrapper {
   }

   /* Hook related methods */
-  bool HasHook() const { return !hooks_.empty(); }
-
-  bool HasMutableHook() const { return !mutable_hooks_.empty(); }
+  bool HasVariableWrapperHook() const { return !var_hooks_.empty(); }

-  int64_t AddHook(std::shared_ptr<VariableWrapperHook>&& hook) {
-    hooks_.emplace(next_hook_id_, std::move(hook));
+  int64_t AddVariableWrapperHook(std::shared_ptr<VariableWrapperHook>&& hook) {
+    var_hooks_.emplace(next_hook_id_, std::move(hook));
     return next_hook_id_++;
   }

-  bool RemoveHook(const int64_t& hook_id) {
-    auto remove_cnt = hooks_.erase(hook_id);
+  bool RemoveVariableWrapperHook(const int64_t& hook_id) {
+    auto remove_cnt = var_hooks_.erase(hook_id);
     if (remove_cnt == 0) {
       return false;
     }
     return true;
   }

-  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>& GetHooks()
-      const {
-    return hooks_;
+  const std::map<int64_t, std::shared_ptr<VariableWrapperHook>>&
+  GetVariableWrapperHooks() const {
+    return var_hooks_;
   }

-  void AddMutableHook(std::shared_ptr<InplaceVariableWrapperHook>&& hook) {
-    mutable_hooks_.emplace_back(std::move(hook));
+  bool HasVoidHook() const { return !void_hooks_.empty(); }
+
+  void AddVoidHook(std::shared_ptr<std::function<void()>>&& hook) {
+    void_hooks_.emplace_back(std::move(hook));
   }

-  const std::vector<std::shared_ptr<InplaceVariableWrapperHook>>&
-  GetMutableHooks() const {
-    return mutable_hooks_;
+  const std::vector<std::shared_ptr<std::function<void()>>>& GetVoidHooks()
+      const {
+    return void_hooks_;
   }

  private:

@@ -319,14 +319,19 @@ class VariableWrapper {
   // isn't need
   bool is_empty_{false};

-  // NOTE(chenweihang): only grad var can hold hooks now
+  // NOTE(chenweihang): only grad var will hold hooks now
   int64_t next_hook_id_{0};
-  // Hooks used to register hook for grad var, support adding and removing,
+  // [ Hooks with VariableWrapper as input and output ]
+  // NOTE: Now registered for grad var, support adding and removing,
   // key is the accumulated int64_t value
-  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> hooks_;
-  // Hooks executed after the execution of the entire backward process is over,
-  // currently only supported for reducing in distributed training
-  std::vector<std::shared_ptr<InplaceVariableWrapperHook>> mutable_hooks_;
+  // NOTE: Var hook need to support removing, so need hook id
+  std::map<int64_t, std::shared_ptr<VariableWrapperHook>> var_hooks_;
+  // [ Hooks without input and output ]
+  // NOTE: Now registered after the execution of the entire backward
+  // process is over, currently only used for reducing in distributed
+  // training
+  // NOTE: Now no need to support remove void hook
+  std::vector<std::shared_ptr<std::function<void()>>> void_hooks_;
 };

 }  // namespace imperative
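VariableWrapper now keeps the two hook kinds in separate containers: var_hooks_ is a map keyed by an auto-incremented id so an individual VariableWrapper hook can be removed again, while void_hooks_ is a plain vector because, per the NOTE above, void hooks never need removal. On the Python side this matches the existing tests: `register_hook` hands back a removable helper, whereas `_register_backward_hook` returns None. A sketch of the removal path, assuming the hook-doubling pattern from the tests:

    import paddle

    x = paddle.to_tensor([1., 2., 3.])
    x.stop_gradient = False

    # register_hook stores the hook in var_hooks_ under a fresh id and
    # returns a helper that wraps that id ...
    helper = x.register_hook(lambda grad: grad * 2)

    # ... so the hook can be erased from the map again before backward runs.
    helper.remove()

    y = paddle.sum(x)
    y.backward()
    # The hook was removed, so x.grad stays the plain gradient of sum: all ones.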
paddle/fluid/pybind/imperative.cc

@@ -1069,20 +1069,58 @@ void BindImperative(py::module *m_ptr) {
       .def("_register_grad_hook",
           [](imperative::VarBase &self, const py::handle &hook) {
             PADDLE_ENFORCE_EQ(
-                self.HasGradVar(), true,
+                !self.OverridedStopGradient() && self.HasGradVar(), true,
                 platform::errors::InvalidArgument(
-                    "Cannot register hook on a tensor without gradient."));
-            return self.GradVarBase()->AddHook(
+                    "Cannot register gradient hook on a Tensor that stop "
+                    "gradient or without gradient."));
+            return self.GradVarBase()->AddVariableWrapperHook(
                 std::make_shared<PyVariableWrapperHook>(hook.ptr()));
           })
       .def("_remove_grad_hook",
           [](imperative::VarBase &self, int64_t hook_id) {
             PADDLE_ENFORCE_EQ(
-                self.HasGradVar(), true,
+                !self.OverridedStopGradient() && self.HasGradVar(), true,
                 platform::errors::InvalidArgument(
-                    "Cannot remove hook on a tensor without gradient."));
-            return self.GradVarBase()->RemoveHook(hook_id);
+                    "Cannot remove gradient hook on a Tensor that stop "
+                    "gradient or without gradient."));
+            return self.GradVarBase()->RemoveVariableWrapperHook(hook_id);
           })
+      .def("_register_backward_hook",
+          [](imperative::VarBase &self, const py::handle &hook) {
+            PADDLE_ENFORCE_EQ(
+                self.IsLeaf(), true,
+                platform::errors::InvalidArgument(
+                    "Only can register backward hook for leaf Tensor."));
+            PADDLE_ENFORCE_EQ(
+                !self.OverridedStopGradient() && self.HasGradVar(), true,
+                platform::errors::InvalidArgument(
+                    "Cannot register backward hook on a Tensor that stop "
+                    "gradient or without gradient."));
+            auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
+            self.GradVarBase()->AddVoidHook(
+                std::make_shared<std::function<void()>>(py_func));
+          },
+          R"DOC(
+            Registers a backward hook for current Tensor.
+
+            This hook will be called every time the gradient of current Tensor has been fully calculated.
+
+            There are two differences with `_register_grad_hook`:
+            1. This backward hook will be executed after the gradient accumulation completed across batchs,
+               but the hook registered by `_register_grad_hook` will be executed the gradient accumulation
+               completed in current batch.
+            2. This backward hook function should have the following signature:
+
+                 hook() -> None
+
+               It requires no input and no return value.
+
+            Args:
+                hook(function): A backward hook to be registered for Tensor.gradient
+
+            Returns:
+                None
+          )DOC")
       .def("cpu",
           [](const std::shared_ptr<imperative::VarBase> &self) {
             if (platform::is_cpu_place(self->Place())) {

@@ -1301,21 +1339,15 @@ void BindImperative(py::module *m_ptr) {
                     &imperative::VarBase::SetOverridedStopGradient)
       .def_property("persistable", &imperative::VarBase::Persistable,
                     &imperative::VarBase::SetPersistable)
-      .def_property_readonly("shape",
-                             [](imperative::VarBase &self) {
-                               if (self.Var().IsType<framework::LoDTensor>()) {
-                                 return framework::vectorize<int>(
-                                     self.Var()
-                                         .Get<framework::LoDTensor>()
-                                         .dims());
-                               } else if (self.Var()
-                                              .IsType<
-                                                  framework::SelectedRows>()) {
-                                 return framework::vectorize<int>(
-                                     self.Var()
-                                         .Get<framework::SelectedRows>()
-                                         .value()
-                                         .dims());
+      .def_property_readonly(
+          "shape",
+          [](imperative::VarBase &self) {
+            if (self.Var().IsType<framework::LoDTensor>()) {
+              return framework::vectorize<int>(
+                  self.Var().Get<framework::LoDTensor>().dims());
+            } else if (self.Var().IsType<framework::SelectedRows>()) {
+              return framework::vectorize<int>(
+                  self.Var().Get<framework::SelectedRows>().value().dims());
             } else {
               VLOG(2) << "It is meaningless to get shape of "
                          "variable type "
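The docstring above is backed by the two PADDLE_ENFORCE_EQ checks in the binding: `_register_backward_hook` only accepts a leaf Tensor that requires gradient, and the registered callable takes no arguments and returns nothing. The Python tests added below exercise exactly these failure modes, which surface as ValueError on the Python side; a condensed sketch:

    import paddle

    def void_hook():
        print("gradient accumulation for x finished")

    # Accepted: a leaf Tensor with stop_gradient=False
    x = paddle.to_tensor(5., stop_gradient=False)
    x._register_backward_hook(void_hook)

    # Rejected: y is an interior (non-leaf) var
    y = paddle.pow(x, 4.0)
    try:
        y._register_backward_hook(void_hook)
    except ValueError:
        pass  # "Only can register backward hook for leaf Tensor."

    # Rejected: a leaf Tensor with stop_gradient=True has no grad var
    w = paddle.to_tensor(5.)
    try:
        w._register_backward_hook(void_hook)
    except ValueError:
        pass  # "Cannot register backward hook on a Tensor that stop gradient ..."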
python/paddle/fluid/tests/unittests/test_tensor_register_hook.py

@@ -178,8 +178,9 @@ class TestTensorRegisterHook(unittest.TestCase):
         # register hook and removed
         run_double_hook_for_leaf_var(lambda grad: grad * 2, removed=True)

-    def test_hook_for_accumulated_grad(self):
-        def run_double_hook_for_accumulated_grad(double_hook, removed=False):
+    def test_hook_for_accumulated_grad_interior_var(self):
+        def run_double_hook_for_accumulated_grad_interior_var(double_hook,
+                                                              removed=False):
             for device in self.devices:
                 paddle.set_device(device)

@@ -227,9 +228,50 @@ class TestTensorRegisterHook(unittest.TestCase):
                                    if not removed else base_grad))

         # register hook
-        run_double_hook_for_accumulated_grad(lambda grad: grad * 2)
+        run_double_hook_for_accumulated_grad_interior_var(lambda grad: grad * 2)
         # register hook and removed
-        run_double_hook_for_accumulated_grad(
+        run_double_hook_for_accumulated_grad_interior_var(
             lambda grad: grad * 2, removed=True)

+    def test_hook_for_accumulated_grad_leaf_var(self):
+        def run_double_hook_for_accumulated_grad_leaf_var(double_hook,
+                                                          removed=False):
+            for device in self.devices:
+                paddle.set_device(device)
+
+                x = paddle.to_tensor([0., 1., 2., 4.])
+                x.stop_gradient = False
+
+                helper = x.register_hook(double_hook)
+
+                y = paddle.to_tensor([4., 5., 6., 7.])
+                z = paddle.to_tensor([1., 2., 3., 4.])
+                y.stop_gradient = False
+                z.stop_gradient = False
+
+                o1 = x + y
+                o2 = x + z
+                o1.stop_gradient = False
+                o2.stop_gradient = False
+
+                o = o1.matmul(o2)
+
+                # remove hook before backward
+                if removed:
+                    helper.remove()
+
+                o.backward()
+
+                base_grad = np.array([5., 9., 13., 19.])
+                # x.grad is changed by x.hook
+                self.assertTrue(
+                    np.array_equal(x.grad, base_grad * 2
+                                   if not removed else base_grad))
+
+        # register hook
+        run_double_hook_for_accumulated_grad_leaf_var(lambda grad: grad * 2)
+        # register hook and removed
+        run_double_hook_for_accumulated_grad_leaf_var(
+            lambda grad: grad * 2, removed=True)
+
     def test_hook_in_model(self):

@@ -409,5 +451,54 @@ class TestTensorRegisterHook(unittest.TestCase):
                 x.register_hook(lambda grad: grad * 2)


+HOOK_INIT_VALUE = 10
+HOOK_IS_CALLED = False
+
+
+def global_void_hook():
+    global HOOK_INIT_VALUE
+    global HOOK_IS_CALLED
+    HOOK_INIT_VALUE *= 2
+    HOOK_IS_CALLED = True
+
+
+class TestTensorRegisterBackwardHook(unittest.TestCase):
+    def setUp(self):
+        self.devices = ["cpu"]
+        if paddle.is_compiled_with_cuda():
+            self.devices.append("gpu")
+
+    def test_register_backward_hook(self):
+        global HOOK_INIT_VALUE
+        global HOOK_IS_CALLED
+        for device in self.devices:
+            x = paddle.to_tensor(5., stop_gradient=False)
+            x._register_backward_hook(global_void_hook)
+            for i in range(5):
+                y = paddle.pow(x, 4.0)
+                y.backward()
+
+            self.assertEqual(HOOK_INIT_VALUE, 320)
+            self.assertTrue(HOOK_IS_CALLED)
+
+            # reset initial value
+            HOOK_INIT_VALUE = 10
+            HOOK_IS_CALLED = False
+
+    def test_register_backward_hook_for_interior_var(self):
+        x = paddle.to_tensor(5., stop_gradient=False)
+        y = paddle.pow(x, 4.0)
+
+        with self.assertRaises(ValueError):
+            y._register_backward_hook(global_void_hook)
+
+    def test_register_backward_hook_for_var_without_gradient(self):
+        x = paddle.to_tensor(5.)
+        y = paddle.pow(x, 4.0)
+
+        with self.assertRaises(ValueError):
+            x._register_backward_hook(global_void_hook)
+
+
 if __name__ == '__main__':
     unittest.main()