Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
08033c86
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
08033c86
编写于
2月 12, 2020
作者:
Z
Zeng Jinle
提交者:
GitHub
2月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix traced layer with non persistable vars, test=develop (#22552)
上级
31b54646
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
141 addition
and
17 deletion
+141
-17
paddle/fluid/imperative/jit/program_desc_tracer.cc
paddle/fluid/imperative/jit/program_desc_tracer.cc
+43
-11
paddle/fluid/imperative/jit/program_desc_tracer.h
paddle/fluid/imperative/jit/program_desc_tracer.h
+9
-2
python/paddle/fluid/dygraph/jit.py
python/paddle/fluid/dygraph/jit.py
+4
-4
python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py
...unittests/test_imperative_trace_non_persistable_inputs.py
+85
-0
未找到文件。
paddle/fluid/imperative/jit/program_desc_tracer.cc
浏览文件 @
08033c86
...
@@ -25,6 +25,7 @@ namespace jit {
...
@@ -25,6 +25,7 @@ namespace jit {
class
UniqueBlockVarGenerator
{
class
UniqueBlockVarGenerator
{
public:
public:
UniqueBlockVarGenerator
(
const
VarDescMetaMap
&
all_vars
,
UniqueBlockVarGenerator
(
const
VarDescMetaMap
&
all_vars
,
const
VarBaseSet
&
non_exist_input_vars
,
framework
::
BlockDesc
*
block
);
framework
::
BlockDesc
*
block
);
std
::
string
NameOf
(
const
std
::
weak_ptr
<
VarBase
>
&
var
,
std
::
string
NameOf
(
const
std
::
weak_ptr
<
VarBase
>
&
var
,
...
@@ -33,7 +34,8 @@ class UniqueBlockVarGenerator {
...
@@ -33,7 +34,8 @@ class UniqueBlockVarGenerator {
private:
private:
void
InsertNewVarInBlock
(
const
std
::
weak_ptr
<
VarBase
>
&
var
,
void
InsertNewVarInBlock
(
const
std
::
weak_ptr
<
VarBase
>
&
var
,
const
framework
::
VarDesc
&
ref_desc
,
const
framework
::
VarDesc
&
ref_desc
,
const
std
::
string
&
name
);
const
std
::
string
&
name
,
bool
force_persistable
=
false
);
private:
private:
const
VarDescMetaMap
&
all_vars_
;
const
VarDescMetaMap
&
all_vars_
;
...
@@ -46,13 +48,18 @@ class UniqueBlockVarGenerator {
...
@@ -46,13 +48,18 @@ class UniqueBlockVarGenerator {
std
::
unordered_set
<
std
::
string
>
existing_names_
;
std
::
unordered_set
<
std
::
string
>
existing_names_
;
};
};
UniqueBlockVarGenerator
::
UniqueBlockVarGenerator
(
const
VarDescMetaMap
&
all_vars
,
UniqueBlockVarGenerator
::
UniqueBlockVarGenerator
(
framework
::
BlockDesc
*
block
)
const
VarDescMetaMap
&
all_vars
,
const
VarBaseSet
&
non_exist_input_vars
,
framework
::
BlockDesc
*
block
)
:
all_vars_
(
all_vars
),
block_
(
block
)
{
:
all_vars_
(
all_vars
),
block_
(
block
)
{
for
(
auto
&
var_pair
:
all_vars_
)
{
for
(
auto
&
var_pair
:
all_vars_
)
{
auto
*
var_desc
=
var_pair
.
second
.
get
();
auto
*
var_desc
=
var_pair
.
second
.
get
();
if
(
var_desc
->
Persistable
())
{
if
(
var_desc
->
Persistable
())
{
InsertNewVarInBlock
(
var_pair
.
first
,
*
var_desc
,
var_desc
->
Name
());
InsertNewVarInBlock
(
var_pair
.
first
,
*
var_desc
,
var_desc
->
Name
());
}
else
if
(
non_exist_input_vars
.
count
(
var_pair
.
first
.
lock
())
>
0
)
{
VLOG
(
10
)
<<
"Mark "
<<
var_desc
->
Name
()
<<
" as persistable"
;
InsertNewVarInBlock
(
var_pair
.
first
,
*
var_desc
,
var_desc
->
Name
(),
/*force_persistable=*/
true
);
}
}
}
}
}
}
...
@@ -90,12 +97,15 @@ std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
...
@@ -90,12 +97,15 @@ std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
void
UniqueBlockVarGenerator
::
InsertNewVarInBlock
(
void
UniqueBlockVarGenerator
::
InsertNewVarInBlock
(
const
std
::
weak_ptr
<
VarBase
>
&
var
,
const
framework
::
VarDesc
&
var_desc
,
const
std
::
weak_ptr
<
VarBase
>
&
var
,
const
framework
::
VarDesc
&
var_desc
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
,
bool
force_persistable
)
{
var_to_name_
[
var
]
=
name
;
var_to_name_
[
var
]
=
name
;
existing_names_
.
insert
(
name
);
existing_names_
.
insert
(
name
);
auto
*
new_var_desc
=
block_
->
Var
(
name
);
auto
*
new_var_desc
=
block_
->
Var
(
name
);
*
new_var_desc
=
var_desc
;
*
new_var_desc
=
var_desc
;
new_var_desc
->
SetName
(
name
);
new_var_desc
->
SetName
(
name
);
if
(
force_persistable
)
{
new_var_desc
->
SetPersistable
(
true
);
}
}
}
void
ProgramDescTracer
::
InsertOp
(
const
std
::
string
&
type
,
void
ProgramDescTracer
::
InsertOp
(
const
std
::
string
&
type
,
...
@@ -106,13 +116,13 @@ void ProgramDescTracer::InsertOp(const std::string &type,
...
@@ -106,13 +116,13 @@ void ProgramDescTracer::InsertOp(const std::string &type,
auto
&
new_op
=
ops_
.
back
();
auto
&
new_op
=
ops_
.
back
();
for
(
auto
&
pair
:
new_op
->
Inputs
())
{
for
(
auto
&
pair
:
new_op
->
Inputs
())
{
for
(
auto
&
var
:
pair
.
second
)
{
for
(
auto
&
var
:
pair
.
second
)
{
InsertVarIfNotExist
(
var
.
lock
());
InsertVarIfNotExist
(
var
.
lock
()
,
true
);
}
}
}
}
for
(
auto
&
pair
:
new_op
->
Outputs
())
{
for
(
auto
&
pair
:
new_op
->
Outputs
())
{
for
(
auto
&
var
:
pair
.
second
)
{
for
(
auto
&
var
:
pair
.
second
)
{
InsertVarIfNotExist
(
var
.
lock
());
InsertVarIfNotExist
(
var
.
lock
()
,
false
);
}
}
}
}
}
}
...
@@ -125,7 +135,12 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
...
@@ -125,7 +135,12 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std
::
unique_ptr
<
framework
::
ProgramDesc
>
prog
(
new
framework
::
ProgramDesc
());
std
::
unique_ptr
<
framework
::
ProgramDesc
>
prog
(
new
framework
::
ProgramDesc
());
auto
*
block
=
prog
->
MutableBlock
(
0
);
auto
*
block
=
prog
->
MutableBlock
(
0
);
UniqueBlockVarGenerator
generator
(
vars_
,
block
);
auto
non_exist_vars_copy
=
non_exist_input_vars_
;
for
(
auto
&
feed_var
:
feed_vars
)
{
non_exist_vars_copy
.
erase
(
feed_var
);
}
UniqueBlockVarGenerator
generator
(
vars_
,
non_exist_vars_copy
,
block
);
std
::
vector
<
std
::
string
>
feed_var_names
;
std
::
vector
<
std
::
string
>
feed_var_names
;
for
(
auto
&
feed_var
:
feed_vars
)
{
for
(
auto
&
feed_var
:
feed_vars
)
{
...
@@ -164,21 +179,37 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
...
@@ -164,21 +179,37 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
}
}
prog
->
Flush
();
prog
->
Flush
();
std
::
vector
<
std
::
shared_ptr
<
VarBase
>>
persistable_vars
(
non_exist_vars_copy
.
begin
(),
non_exist_vars_copy
.
end
());
for
(
auto
&
pair
:
vars_
)
{
if
(
pair
.
second
->
Persistable
())
{
auto
var
=
pair
.
first
.
lock
();
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"Persistable var %s does not exist"
,
pair
.
second
->
Name
()));
persistable_vars
.
emplace_back
(
var
);
}
}
return
std
::
make_tuple
(
std
::
move
(
prog
),
std
::
move
(
feed_var_names
),
return
std
::
make_tuple
(
std
::
move
(
prog
),
std
::
move
(
feed_var_names
),
std
::
move
(
fetch_var_names
));
std
::
move
(
fetch_var_names
),
std
::
move
(
persistable_vars
));
}
}
void
ProgramDescTracer
::
InsertVarIfNotExist
(
void
ProgramDescTracer
::
InsertVarIfNotExist
(
const
std
::
shared_ptr
<
VarBase
>
&
new_var
)
{
const
std
::
shared_ptr
<
VarBase
>
&
new_var
,
bool
is_input
)
{
PADDLE_ENFORCE_NOT_NULL
(
new_var
);
PADDLE_ENFORCE_NOT_NULL
(
new_var
);
if
(
vars_
.
count
(
new_var
)
!=
0
)
return
;
if
(
vars_
.
count
(
new_var
)
!=
0
)
return
;
auto
new_var_desc
=
new
framework
::
VarDesc
(
""
);
auto
new_var_desc
=
new
framework
::
VarDesc
(
""
);
vars_
[
new_var
].
reset
(
new_var_desc
);
vars_
[
new_var
].
reset
(
new_var_desc
);
if
(
new_var
->
Persistable
())
{
if
(
new_var
->
Persistable
()
||
is_input
)
{
new_var_desc
->
SetName
(
new_var
->
Name
());
new_var_desc
->
SetName
(
new_var
->
Name
());
new_var_desc
->
SetPersistable
(
true
);
new_var_desc
->
SetPersistable
(
new_var
->
Persistable
());
if
(
!
new_var
->
Persistable
())
{
non_exist_input_vars_
.
insert
(
new_var
);
}
}
else
{
}
else
{
new_var_desc
->
SetPersistable
(
false
);
new_var_desc
->
SetPersistable
(
false
);
}
}
...
@@ -204,6 +235,7 @@ void ProgramDescTracer::InsertVarIfNotExist(
...
@@ -204,6 +235,7 @@ void ProgramDescTracer::InsertVarIfNotExist(
void
ProgramDescTracer
::
Reset
()
{
void
ProgramDescTracer
::
Reset
()
{
ops_
.
clear
();
ops_
.
clear
();
vars_
.
clear
();
vars_
.
clear
();
non_exist_input_vars_
.
clear
();
}
}
}
// namespace jit
}
// namespace jit
...
...
paddle/fluid/imperative/jit/program_desc_tracer.h
浏览文件 @
08033c86
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <set>
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include <utility>
#include <utility>
...
@@ -34,10 +35,14 @@ using VarDescMetaMap =
...
@@ -34,10 +35,14 @@ using VarDescMetaMap =
std
::
map
<
std
::
weak_ptr
<
VarBase
>
,
std
::
unique_ptr
<
framework
::
VarDesc
>
,
std
::
map
<
std
::
weak_ptr
<
VarBase
>
,
std
::
unique_ptr
<
framework
::
VarDesc
>
,
std
::
owner_less
<
std
::
weak_ptr
<
VarBase
>>>
;
std
::
owner_less
<
std
::
weak_ptr
<
VarBase
>>>
;
using
VarBaseSet
=
std
::
set
<
std
::
shared_ptr
<
VarBase
>
,
std
::
owner_less
<
std
::
shared_ptr
<
VarBase
>>>
;
using
TracedProgramTuple
=
using
TracedProgramTuple
=
std
::
tuple
<
std
::
unique_ptr
<
framework
::
ProgramDesc
>
/*program*/
,
std
::
tuple
<
std
::
unique_ptr
<
framework
::
ProgramDesc
>
/*program*/
,
std
::
vector
<
std
::
string
>
/*feed_var_names*/
,
std
::
vector
<
std
::
string
>
/*feed_var_names*/
,
std
::
vector
<
std
::
string
>
/*fetch_var_names*/
>
;
std
::
vector
<
std
::
string
>
/*fetch_var_names*/
,
std
::
vector
<
std
::
shared_ptr
<
VarBase
>>
/*persistable_vars*/
>
;
class
ProgramDescTracer
{
class
ProgramDescTracer
{
DISABLE_COPY_AND_ASSIGN
(
ProgramDescTracer
);
DISABLE_COPY_AND_ASSIGN
(
ProgramDescTracer
);
...
@@ -58,11 +63,13 @@ class ProgramDescTracer {
...
@@ -58,11 +63,13 @@ class ProgramDescTracer {
void
Reset
();
void
Reset
();
private:
private:
void
InsertVarIfNotExist
(
const
std
::
shared_ptr
<
VarBase
>
&
new_var
);
void
InsertVarIfNotExist
(
const
std
::
shared_ptr
<
VarBase
>
&
new_var
,
bool
is_input
);
private:
private:
std
::
vector
<
std
::
unique_ptr
<
OpDescMeta
>>
ops_
;
std
::
vector
<
std
::
unique_ptr
<
OpDescMeta
>>
ops_
;
VarDescMetaMap
vars_
;
VarDescMetaMap
vars_
;
VarBaseSet
non_exist_input_vars_
;
};
};
}
// namespace jit
}
// namespace jit
...
...
python/paddle/fluid/dygraph/jit.py
浏览文件 @
08033c86
...
@@ -93,14 +93,14 @@ def _trace(layer,
...
@@ -93,14 +93,14 @@ def _trace(layer,
outputs
=
original_outputs
outputs
=
original_outputs
out_vars
=
[
var
for
var
in
outputs
]
out_vars
=
[
var
for
var
in
outputs
]
program_desc
,
feed_names
,
fetch_names
=
tracer
.
create_program_desc
(
program_desc
,
feed_names
,
fetch_names
,
parameters
=
tracer
.
create_program_desc
(
var_list
,
feed_prefix
,
out_vars
,
fetch_prefix
,
tmp_prefix
)
var_list
,
feed_prefix
,
out_vars
,
fetch_prefix
,
tmp_prefix
)
tracer
.
reset
()
tracer
.
reset
()
with
_dygraph_guard
(
None
):
with
_dygraph_guard
(
None
):
program
=
create_program_from_desc
(
program_desc
)
program
=
create_program_from_desc
(
program_desc
)
return
original_outputs
,
program
,
feed_names
,
fetch_names
return
original_outputs
,
program
,
feed_names
,
fetch_names
,
parameters
class
TracedLayer
(
object
):
class
TracedLayer
(
object
):
...
@@ -199,8 +199,8 @@ class TracedLayer(object):
...
@@ -199,8 +199,8 @@ class TracedLayer(object):
# save the static graph model for inference
# save the static graph model for inference
static_layer.save_inference_model(dirname='./saved_infer_model')
static_layer.save_inference_model(dirname='./saved_infer_model')
"""
"""
outs
,
prog
,
feed
,
fetch
=
_trace
(
layer
,
inputs
)
outs
,
prog
,
feed
,
fetch
,
parameters
=
_trace
(
layer
,
inputs
)
traced
=
TracedLayer
(
prog
,
layer
.
parameters
()
,
feed
,
fetch
)
traced
=
TracedLayer
(
prog
,
parameters
,
feed
,
fetch
)
return
outs
,
traced
return
outs
,
traced
def
set_strategy
(
self
,
build_strategy
=
None
,
exec_strategy
=
None
):
def
set_strategy
(
self
,
build_strategy
=
None
,
exec_strategy
=
None
):
...
...
python/paddle/fluid/tests/unittests/test_imperative_trace_non_persistable_inputs.py
0 → 100644
浏览文件 @
08033c86
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
numpy
as
np
import
six
import
os
class
SimpleFCLayer
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
feature_size
,
batch_size
,
fc_size
):
super
(
SimpleFCLayer
,
self
).
__init__
()
self
.
_linear
=
fluid
.
dygraph
.
Linear
(
feature_size
,
fc_size
)
self
.
_offset
=
fluid
.
dygraph
.
to_variable
(
np
.
random
.
random
((
batch_size
,
fc_size
)).
astype
(
'float32'
))
def
forward
(
self
,
x
):
fc
=
self
.
_linear
(
x
)
return
fc
+
self
.
_offset
class
TestTracedLayerRecordNonPersistableInput
(
unittest
.
TestCase
):
def
test_main
(
self
):
traced_layer
=
None
with
fluid
.
dygraph
.
guard
():
feature_size
=
3
batch_size
=
4
fc_size
=
2
layer
=
SimpleFCLayer
(
feature_size
,
batch_size
,
fc_size
)
optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
1e-3
,
parameter_list
=
layer
.
parameters
())
expected_persistable_vars
=
set
([
layer
.
_linear
.
weight
.
name
,
layer
.
_linear
.
bias
.
name
,
layer
.
_offset
.
name
])
for
_
in
six
.
moves
.
range
(
10
):
in_x
=
fluid
.
dygraph
.
to_variable
(
np
.
random
.
random
((
batch_size
,
feature_size
)).
astype
(
'float32'
))
if
traced_layer
is
None
:
dygraph_out
,
traced_layer
=
fluid
.
dygraph
.
TracedLayer
.
trace
(
layer
,
[
in_x
])
else
:
dygraph_out
=
layer
(
in_x
)
dygraph_out_numpy
=
dygraph_out
.
numpy
()
static_out
=
traced_layer
([
in_x
])[
0
]
self
.
assertTrue
(
np
.
array_equal
(
dygraph_out_numpy
,
static_out
))
loss
=
fluid
.
layers
.
reduce_mean
(
dygraph_out
)
loss
.
backward
()
optimizer
.
minimize
(
loss
)
del
layer
program
=
traced_layer
.
program
actual_persistable_vars
=
set
()
for
var
in
program
.
list_vars
():
if
var
.
persistable
:
actual_persistable_vars
.
add
(
var
.
name
)
self
.
assertEqual
(
actual_persistable_vars
,
expected_persistable_vars
)
dirname
=
'./traced_layer_test_non_persistable_vars'
traced_layer
.
save_inference_model
(
dirname
=
dirname
)
filenames
=
set
([
f
for
f
in
os
.
listdir
(
dirname
)
if
f
!=
'__model__'
])
self
.
assertEqual
(
filenames
,
expected_persistable_vars
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录