Unverified commit 35c7c835, authored May 12, 2022 by Jiabin Yang, committed via GitHub on May 12, 2022
[Eager] Remove full reserved strategy (#42690)
* remove full reserved strategy
* fix inplace error
Parent: 6e90ba1b
Showing 4 changed files with 51 additions and 84 deletions (+51 −84)
paddle/fluid/eager/auto_code_generator/eager_generator.cc  +10 −21
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py  +8 −8
paddle/fluid/eager/tensor_wrapper.h  +25 −51
paddle/fluid/eager/tests/data_structure_tests/tensor_wrapper_test.cc  +8 −4
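In short, this commit drops the full_reserved flag from egr::TensorWrapper and from every generated SetTensorWrapper* method, so a wrapper always saves only the tensor data plus a copy of its autograd meta. A condensed before/after of the constructor, taken from the tensor_wrapper.h diff below:

// Before this commit:
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
                       bool full_reserved = false,
                       bool no_need_buffer = false);

// After this commit: the original grad node is re-attached on recover()
// through the weak_grad_node_ member instead of being fully reserved.
explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
                       bool no_need_buffer = false);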
paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -1156,28 +1156,20 @@ static std::string GenerateGradNodeCreationContent(
   for (const auto& iter : op_base_infos) {
     const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
         iter.GetGradInsFwdSlotnameMap();
-    const std::unordered_set<std::string>& no_need_buffer_ins =
-        iter.GetNoNeedBufferInputs();
     for (auto& kv : grad_ins_fwd_slotname_map) {
       const std::string& tensor_wrapper_name = kv.second;
-      std::string full_reserved = "false";
-      if (fwd_outputs_name_pos_map.find(tensor_wrapper_name) ==
-              fwd_outputs_name_pos_map.end() &&
-          !no_need_buffer_ins.count(tensor_wrapper_name)) {
-        full_reserved = "true";
-      }
       const char* SET_TENSOR_WRAPPER_TEMPLATE =
-          " grad_node->SetTensorWrapper%s(%s, %s);\n";
+          " grad_node->SetTensorWrapper%s(%s);\n";
       // Replace output directly with input in inplace op.
       if (!inplace_map.empty() && inplace_map.count(tensor_wrapper_name)) {
         auto inplace_input_name = inplace_map[tensor_wrapper_name];
         grad_node_creation_str += paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
-            LegalizeVarName(inplace_input_name), full_reserved);
+            LegalizeVarName(inplace_input_name));
       } else {
         grad_node_creation_str += paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_TEMPLATE, LegalizeVarName(tensor_wrapper_name),
-            LegalizeVarName(tensor_wrapper_name), full_reserved);
+            LegalizeVarName(tensor_wrapper_name));
       }
     }
   }
@@ -2592,7 +2584,6 @@ static std::string GenerateGradNodeHeaderContents(
       std::string tensor_wrapper_arg_str;
       std::string tensor_wrapper_body_str;
-      std::string full_reserved_str = "full_reserved";
       std::string no_need_buffer_str = "false";
       if (no_need_buffer_ins.count(tensor_wrapper_name)) {
         no_need_buffer_str = "true";
@@ -2610,12 +2601,12 @@ static std::string GenerateGradNodeHeaderContents(
         const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
             "for(const auto& eager_tensor : %s) {\n"
-            " %s.emplace_back( egr::TensorWrapper(eager_tensor, %s "
-            "/*full_reserved*/, %s) );\n"
+            " %s.emplace_back( egr::TensorWrapper(eager_tensor "
+            ", %s) );\n"
             " }\n";
         tensor_wrapper_body_str = paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
-            struct_tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
+            struct_tensor_wrapper_name, no_need_buffer_str);
         const char* CLEAR_TENSOR_WRAPPER_TEMPLATE =
             "for (auto tw: %s) {\n"
@@ -2636,22 +2627,20 @@ static std::string GenerateGradNodeHeaderContents(
             TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
         const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
-            "%s = egr::TensorWrapper(%s, %s /*full_reserved*/, %s);\n";
+            "%s = egr::TensorWrapper(%s, %s);\n";
         tensor_wrapper_body_str = paddle::string::Sprintf(
             SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
-            tensor_wrapper_name, full_reserved_str, no_need_buffer_str);
+            tensor_wrapper_name, no_need_buffer_str);
         const char* CLEAR_TENSOR_WRAPPER_TEMPLATE = " %s.clear();\n";
         clear_tensor_wrappers_str += paddle::string::Sprintf(
             CLEAR_TENSOR_WRAPPER_TEMPLATE, struct_tensor_wrapper_name);
       }
-      std::string full_reserved_signature_str = "bool full_reserved";
       const char* SET_TENSOR_WRAPPER_TEMPLATE =
-          " void SetTensorWrapper%s(%s, %s) {\n %s\n }\n";
+          " void SetTensorWrapper%s(%s) {\n %s\n }\n";
       set_tensor_wrappers_str += paddle::string::Sprintf(
           SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
-          tensor_wrapper_arg_str, full_reserved_signature_str,
-          tensor_wrapper_body_str);
+          tensor_wrapper_arg_str, tensor_wrapper_body_str);
     }
   }
   VLOG(6) << "Generated TensorWrapper";
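For reference, the setter that the legacy generator emits from the templates above, rendered by hand for a hypothetical input named X with a generated member X_ (the real argument type and the final no_need_buffer literal depend on the op definition):

// Emitted before this commit:
void SetTensorWrapperX(const paddle::experimental::Tensor& X, bool full_reserved) {
  X_ = egr::TensorWrapper(X, full_reserved /*full_reserved*/, false);
}

// Emitted after this commit:
void SetTensorWrapperX(const paddle::experimental::Tensor& X) {
  X_ = egr::TensorWrapper(X, false);
}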
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -55,8 +55,8 @@ def ParseArguments():
 ## Code Gen Templates ##
 ########################

 SET_PLAIN_TENSOR_WRAPPER_TEMPLATE = \
-""" void SetTensorWrapper{}(const paddle::experimental::Tensor& {}, bool full_reserved) {{
-    {} = egr::TensorWrapper({}, full_reserved, {});
+""" void SetTensorWrapper{}(const paddle::experimental::Tensor& {}) {{
+    {} = egr::TensorWrapper({}, {});
  }}
 """
@@ -69,9 +69,9 @@ CLEAR_TENSOR_WRAPPER_TEMPLATE = \
 """

 SET_VECTOR_TENSOR_WRAPPER_TEMPLATE = \
-""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}, bool full_reserved) {{
+""" void SetTensorWrapper{}(const std::vector<paddle::experimental::Tensor>& {}) {{
     for(const auto& eager_tensor : {}) {{
-        {}.emplace_back(egr::TensorWrapper(eager_tensor, full_reserved, {}));
+        {}.emplace_back(egr::TensorWrapper(eager_tensor, {}));
     }};
  }}
 """
@@ -676,9 +676,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
             if is_fwd_input:
                 if is_optional:
-                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), true);"
+                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
                 else:
-                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, true);"
+                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
                 set_input_tensor_wrappers_list.append(set_tensor_wrappers)
             else:
                 if num_fwd_outputs > 1:
@@ -688,9 +688,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
                     fwd_output_pos = forward_outputs_position_map[name][1]

                 if is_optional:
-                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()), false);"
+                    set_tensor_wrappers = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetTensorWrapper{name}(*({name}.get_ptr()));"
                 else:
-                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name}, false);"
+                    set_tensor_wrappers = f"{indent}grad_node->SetTensorWrapper{name}({name});"
                 set_output_tensor_wrappers_list.append(set_tensor_wrappers)

         set_input_tensor_wrappers_str = "\n".join(
             set_input_tensor_wrappers_list)
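The final-state generator changes in the same way. Rendered by hand from the new SET_VECTOR_TENSOR_WRAPPER_TEMPLATE above for a hypothetical vector input Xs, with Xs_ standing in for the generated member name and false for the per-input no_need_buffer literal:

void SetTensorWrapperXs(const std::vector<paddle::experimental::Tensor>& Xs) {
  for(const auto& eager_tensor : Xs) {
    Xs_.emplace_back(egr::TensorWrapper(eager_tensor, false));
  };
}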
paddle/fluid/eager/tensor_wrapper.h
@@ -34,7 +34,6 @@ class TensorWrapper {
  public:
   TensorWrapper() = default;
   explicit TensorWrapper(const paddle::experimental::Tensor& tensor,
-                         bool full_reserved = false,
                          bool no_need_buffer = false) {
     // set inplace_version_snapshot_ according to tensor's current inplace
     // version.
@@ -46,32 +45,12 @@ class TensorWrapper {
     }

     /**
-     * Normally, we should fully reserved all non-output or non-leaf fwd tensor
-     * here. And for fwd output tensor, we should not reserve its autogradmeta,
-     * to avoid recursive depends on GradNodeBase
+     * Normally, we should only save data and part of autograd_meta of fwd
+     * tensor, and should not reserve its original grad_node,
+     * to avoid recursive and additional depends on GradNodeBase
      * **/
-    full_reserved_ = full_reserved;
+    auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
     no_need_buffer_ = no_need_buffer;
-    if (full_reserved_) {
-      VLOG(6) << "Fully reserved tensor: " << tensor.name();
-      intermidiate_tensor_ = tensor;
-      if (no_need_buffer_) {
-        if (phi::DenseTensor::classof(tensor.impl().get())) {
-          // Only Copy Meta
-          phi::DenseTensor* dense_tensor =
-              static_cast<phi::DenseTensor*>(tensor.impl().get());
-          auto tw_dense_tensor =
-              std::make_shared<phi::DenseTensor>(*dense_tensor);
-          tw_dense_tensor->clear();
-          intermidiate_tensor_.set_impl(tw_dense_tensor);
-        } else {
-          PADDLE_THROW(paddle::platform::errors::Fatal(
-              "Unrecognized tensor type for no_need_buffer feature"));
-        }
-      }
-      return;
-    }
     // shallow copy tensor_impl here
     if (no_need_buffer) {
       if (phi::DenseTensor::classof(tensor.impl().get())) {
@@ -89,10 +68,11 @@ class TensorWrapper {
       intermidiate_tensor_.set_impl(tensor.impl());
     }

-    // TODO(jiabin): This may has server performance issue
-    intermidiate_tensor_.set_name(tensor.name() + "@Saved");
+    if (VLOG_IS_ON(7)) {
+      // TODO(jiabin): This may has server performance issue
+      intermidiate_tensor_.set_name(tensor.name() + "@Saved");
+    }

-    auto* tensor_autograd_meta = EagerUtils::nullable_autograd_meta(tensor);
     if (tensor_autograd_meta) {
       auto autograd_meta =
           std::make_shared<AutogradMeta>(*tensor_autograd_meta);
@@ -112,33 +92,28 @@ class TensorWrapper {
     check_inplace_version();

-    // if it's full_reserved just return the full copy of tensor
-    if (full_reserved_) {
-      return intermidiate_tensor_;
-    } else {
-      paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
+    paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;

-      std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
-      if (new_grad_node) {
-        VLOG(3) << "Recovered TensorWrapper with GradNode "
-                << new_grad_node->name() << " addr: " << new_grad_node.get();
-      } else {
-        VLOG(3) << "Recovered TensorWrapper with Empty GradNode";
-      }
-      auto* intermediate_autograd_meta =
-          EagerUtils::nullable_autograd_meta(intermidiate_tensor_);
+    std::shared_ptr<GradNodeBase> new_grad_node = weak_grad_node_.lock();
+    if (new_grad_node) {
+      VLOG(3) << "Recovered TensorWrapper with GradNode "
+              << new_grad_node->name() << " addr: " << new_grad_node.get();
+    } else {
+      VLOG(3) << "Recovered TensorWrapper with Empty GradNode";
+    }
+    auto* intermediate_autograd_meta =
+        EagerUtils::nullable_autograd_meta(intermidiate_tensor_);

-      if (intermediate_autograd_meta) {
-        auto p_ab_autograd_meta =
-            std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
-        if (new_grad_node) {
-          p_ab_autograd_meta->SetGradNode(new_grad_node);
-        }
-        recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
-      }
+    if (intermediate_autograd_meta) {
+      auto p_ab_autograd_meta =
+          std::make_shared<AutogradMeta>(*intermediate_autograd_meta);
+      if (new_grad_node) {
+        p_ab_autograd_meta->SetGradNode(new_grad_node);
+      }
+      recovered_tensor.set_autograd_meta(p_ab_autograd_meta);
+    }

-      return recovered_tensor;
-    }
+    return recovered_tensor;
   }

   void clear() { intermidiate_tensor_.reset(); }
@@ -179,7 +154,6 @@ class TensorWrapper {
   }

  private:
-  bool full_reserved_ = false;
   bool no_need_buffer_ = false;
   paddle::experimental::Tensor intermidiate_tensor_;
   std::weak_ptr<egr::GradNodeBase> weak_grad_node_;
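Putting the header changes together, a minimal usage sketch with hypothetical tensor names, mirroring the updated test below: wrapping no longer takes a full_reserved argument, and recover() always rebuilds the autograd meta and re-links the grad node through weak_grad_node_ when it is still alive.

paddle::experimental::Tensor t;
t.set_name("t");

auto tw = egr::TensorWrapper(t);             // no full_reserved argument any more
auto recovered = tw.recover();               // copies autograd meta, re-attaches the grad node if alive

auto tw_meta_only = egr::TensorWrapper(t, /*no_need_buffer=*/true);  // keep meta, drop the data buffer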
paddle/fluid/eager/tests/data_structure_tests/tensor_wrapper_test.cc
@@ -40,9 +40,11 @@ TEST(TensorWrapper, Basic) {
       auto auto_grad0 = std::make_shared<egr::AutogradMeta>(edge0);
   et1.set_autograd_meta(auto_grad0);
   et1.set_name("et1");
-  auto tw0 = egr::TensorWrapper(et1, true);
+  auto tw0 = egr::TensorWrapper(et1);
   auto recover_et1 = tw0.recover();
-  CHECK_EQ(recover_et1.name(), std::string("et1"));
+  if (VLOG_IS_ON(7)) {
+    CHECK_EQ(recover_et1.name(), std::string("et1@saved"));
+  }
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).first,
            egr::EagerUtils::OutRankInfo(et1).first);
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et1).second,
@@ -68,13 +70,15 @@ TEST(TensorWrapper, Basic) {
   et2.set_autograd_meta(auto_grad1);
   auto tw1 = egr::TensorWrapper(et2, false);
   auto recover_et2 = tw1.recover();
-  CHECK_EQ(recover_et2.name(), std::string("et2@Saved"));
+  if (VLOG_IS_ON(7)) {
+    CHECK_EQ(recover_et2.name(), std::string("et2@Saved"));
+  }
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).first,
            egr::EagerUtils::OutRankInfo(et2).first);
   CHECK_EQ(egr::EagerUtils::OutRankInfo(recover_et2).second,
            egr::EagerUtils::OutRankInfo(et2).second);
   // Test Raw recover
   paddle::experimental::Tensor et3;
-  auto tw2 = egr::TensorWrapper(et3, true);
+  auto tw2 = egr::TensorWrapper(et3);
   CHECK(tw2.recover().initialized() == false);
 }