Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
dfa96d6d
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
280
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
dfa96d6d
编写于
1月 25, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add a func to connect program
上级
0da9baf5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
96 addition
and
21 deletion
+96
-21
paddle_hub/__init__.py
paddle_hub/__init__.py
+1
-0
paddle_hub/module.py
paddle_hub/module.py
+67
-21
paddle_hub/utils.py
paddle_hub/utils.py
+28
-0
未找到文件。
paddle_hub/__init__.py
浏览文件 @
dfa96d6d
...
@@ -25,3 +25,4 @@ from paddle_hub.module import create_module
...
@@ -25,3 +25,4 @@ from paddle_hub.module import create_module
from
paddle_hub.downloader
import
download_and_uncompress
from
paddle_hub.downloader
import
download_and_uncompress
from
paddle_hub.signature
import
create_signature
from
paddle_hub.signature
import
create_signature
from
paddle_hub.version
import
__version__
from
paddle_hub.version
import
__version__
connect_program
=
ModuleUtils
.
connect_program
paddle_hub/module.py
浏览文件 @
dfa96d6d
...
@@ -23,14 +23,14 @@ import paddle.fluid as fluid
...
@@ -23,14 +23,14 @@ import paddle.fluid as fluid
import
numpy
as
np
import
numpy
as
np
import
tempfile
import
tempfile
import
os
import
os
import
pickle
import
copy
from
collections
import
defaultdict
from
collections
import
defaultdict
from
paddle_hub.downloader
import
download_and_uncompress
from
paddle_hub.downloader
import
download_and_uncompress
from
paddle_hub
import
module_desc_pb2
from
paddle_hub
import
module_desc_pb2
from
paddle_hub.logger
import
logger
from
paddle_hub.logger
import
logger
from
paddle_hub.signature
import
Signature
from
paddle_hub.signature
import
Signature
from
paddle_hub.utils
import
to_list
from
paddle_hub.utils
import
to_list
,
get_variable_info
from
paddle_hub.version
import
__version__
from
paddle_hub.version
import
__version__
__all__
=
[
"Module"
,
"ModuleConfig"
,
"ModuleUtils"
]
__all__
=
[
"Module"
,
"ModuleConfig"
,
"ModuleUtils"
]
...
@@ -235,25 +235,6 @@ class Module(object):
...
@@ -235,25 +235,6 @@ class Module(object):
word_dict
=
self
.
config
.
get_assets_vocab
()
word_dict
=
self
.
config
.
get_assets_vocab
()
return
list
(
map
(
lambda
x
:
word_dict
[
x
],
inputs
))
return
list
(
map
(
lambda
x
:
word_dict
[
x
],
inputs
))
def
set_input
(
self
,
input_dict
):
assert
isinstance
(
input_dict
,
dict
),
"input_dict must be a dict"
if
not
input_dict
:
logger
.
warning
(
"the input_dict is empty"
)
for
key
,
val
in
input_dict
.
items
():
assert
isinstance
(
val
,
fluid
.
framework
.
Variable
),
"the input_dict should be a dict with string-Variable pair"
program
=
val
.
block
.
program
assert
key
in
program
.
global_block
(
).
vars
,
"can't found input %s in the module"
%
key
input_var
=
val
output_var
=
program
.
global_block
().
var
(
key
)
program
.
global_block
().
_prepend_op
(
type
=
"assign"
,
inputs
=
{
'X'
:
input_var
},
outputs
=
{
'Out'
:
output_var
})
class
ModuleConfig
(
object
):
class
ModuleConfig
(
object
):
def
__init__
(
self
,
module_dir
,
module_name
=
None
):
def
__init__
(
self
,
module_dir
,
module_name
=
None
):
...
@@ -476,6 +457,71 @@ class ModuleUtils(object):
...
@@ -476,6 +457,71 @@ class ModuleUtils(object):
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
@
staticmethod
def
connect_program
(
pre_program
,
next_program
,
input_dict
=
None
):
def
_copy_vars_and_ops_in_blocks
(
from_block
,
to_block
):
for
var
in
from_block
.
vars
:
var
=
from_block
.
var
(
var
)
var_info
=
copy
.
deepcopy
(
get_variable_info
(
var
))
if
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
to_block
.
create_parameter
(
**
var_info
)
else
:
to_block
.
create_var
(
**
var_info
)
for
op
in
from_block
.
ops
:
op_info
=
{
'type'
:
op
.
type
,
'inputs'
:
{
input
:
[
block
.
var
(
var
)
for
var
in
op
.
input
(
input
)]
for
input
in
op
.
input_names
},
'outputs'
:
{
output
:
[
block
.
var
(
var
)
for
var
in
op
.
output
(
output
)]
for
output
in
op
.
output_names
},
'attrs'
:
copy
.
deepcopy
(
op
.
all_attrs
())
}
to_block
.
append_op
(
**
op_info
)
assert
isinstance
(
pre_program
,
fluid
.
Program
),
"pre_program should be fluid.Program"
assert
isinstance
(
next_program
,
fluid
.
Program
),
"next_program should be fluid.Program"
new_program
=
pre_program
.
clone
()
if
input_dict
:
assert
isinstance
(
input_dict
,
dict
),
"the input_dict should be a dict with string-Variable pair"
for
key
,
var
in
input_dict
.
items
():
assert
isinstance
(
var
,
fluid
.
framework
.
Variable
),
"the input_dict should be a dict with string-Variable pair"
var_info
=
copy
.
deepcopy
(
get_variable_info
(
var
))
input_var
=
new_program
.
global_block
().
create_var
(
**
var_info
)
output_var
=
next_program
.
global_block
().
var
(
key
)
var_info
=
copy
.
deepcopy
(
get_variable_info
(
output_var
))
output_var
=
new_program
.
global_block
().
create_var
(
**
var_info
)
new_program
.
global_block
().
_prepend_op
(
type
=
"assign"
,
inputs
=
{
'X'
:
input_var
},
outputs
=
{
'Out'
:
output_var
})
block_map
=
{
0
:
0
}
logger
.
info
(
"start to connect program"
)
for
index
,
block
in
enumerate
(
next_program
.
blocks
):
if
block
.
idx
==
0
:
_copy_vars_and_ops_in_blocks
(
block
,
new_program
.
global_block
())
else
:
block_map
[
index
]
=
len
(
new_program
.
blocks
)
logger
.
info
(
"block_%d in next_program merge into block_%d in pre_program"
%
(
index
,
block_map
[
index
]))
new_block
=
new_program
.
_create_block
(
parent_idx
=
block_map
[
block
.
parent_idx
])
_copy_vars_and_ops_in_blocks
(
block
,
new_block
)
logger
.
info
(
"end of connect program"
)
return
new_program
@
staticmethod
@
staticmethod
def
remove_feed_fetch_op
(
program
):
def
remove_feed_fetch_op
(
program
):
""" remove feed and fetch operator and variable for fine-tuning
""" remove feed and fetch operator and variable for fine-tuning
...
...
paddle_hub/utils.py
浏览文件 @
dfa96d6d
...
@@ -17,6 +17,8 @@
...
@@ -17,6 +17,8 @@
from
__future__
import
absolute_import
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
def
to_list
(
input
):
def
to_list
(
input
):
...
@@ -25,3 +27,29 @@ def to_list(input):
...
@@ -25,3 +27,29 @@ def to_list(input):
input
=
[
input
]
input
=
[
input
]
return
input
return
input
def
get_variable_info
(
var
):
assert
isinstance
(
var
,
fluid
.
framework
.
Variable
),
"var should be a fluid.framework.Variable"
var_info
=
{
'type'
:
var
.
type
,
'name'
:
var
.
name
,
'dtype'
:
var
.
dtype
,
'lod_level'
:
var
.
lod_level
,
'shape'
:
var
.
shape
,
'stop_gradient'
:
var
.
stop_gradient
,
'is_data'
:
var
.
is_data
,
'error_clip'
:
var
.
error_clip
}
if
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
var_info
[
'trainable'
]
=
var
.
trainable
var_info
[
'optimize_attr'
]
=
var
.
optimize_attr
var_info
[
'regularizer'
]
=
var
.
regularizer
var_info
[
'gradient_clip_attr'
]
=
var
.
gradient_clip_attr
var_info
[
'do_model_average'
]
=
var
.
do_model_average
else
:
var_info
[
'persistable'
]
=
var
.
persistable
return
var_info
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录