Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
6acb2dd4
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6acb2dd4
编写于
9月 15, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix module compat bug
上级
b33e4d14
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
436 addition
and
14 deletion
+436
-14
paddlehub/__init__.py
paddlehub/__init__.py
+13
-1
paddlehub/compat/module/module_desc.proto
paddlehub/compat/module/module_desc.proto
+81
-0
paddlehub/compat/module/module_v1.py
paddlehub/compat/module/module_v1.py
+4
-0
paddlehub/compat/paddle_utils.py
paddlehub/compat/paddle_utils.py
+156
-0
paddlehub/module/manager.py
paddlehub/module/manager.py
+2
-2
paddlehub/module/module.py
paddlehub/module/module.py
+43
-9
paddlehub/utils/log.py
paddlehub/utils/log.py
+2
-2
paddlehub/utils/parser.py
paddlehub/utils/parser.py
+75
-0
paddlehub/utils/utils.py
paddlehub/utils/utils.py
+60
-0
未找到文件。
paddlehub/__init__.py
浏览文件 @
6acb2dd4
...
@@ -13,9 +13,21 @@
...
@@ -13,9 +13,21 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
sys
__version__
=
'2.0.0a0'
__version__
=
'2.0.0a0'
from
paddlehub.utils
import
log
,
parser
,
utils
from
paddlehub.module
import
Module
from
paddlehub.module
import
Module
# In order to maintain the compatibility of the old version, we put the relevant
# compatible code in the paddlehub/compat package, and mapped some modules referenced
# in the old version
from
paddlehub.compat
import
paddle_utils
from
paddlehub.compat.module.processor
import
BaseProcessor
from
paddlehub.compat.module.processor
import
BaseProcessor
from
paddlehub.compat.module.nlp_module
import
NLPPredictionModule
,
TransformerModule
from
paddlehub.compat.type
import
DataType
from
paddlehub.compat.type
import
DataType
sys
.
modules
[
'paddlehub.io.parser'
]
=
parser
sys
.
modules
[
'paddlehub.common.logger'
]
=
log
sys
.
modules
[
'paddlehub.common.paddle_helper'
]
=
paddle_utils
sys
.
modules
[
'paddlehub.common.utils'
]
=
utils
paddlehub/compat/module/module_desc.proto
0 → 100644
浏览文件 @
6acb2dd4
// Copyright 2018 The Paddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
syntax
=
"proto3"
;
option
optimize_for
=
LITE_RUNTIME
;
package
paddlehub
.
module.desc
;
enum
DataType
{
NONE
=
0
;
INT
=
1
;
FLOAT
=
2
;
STRING
=
3
;
BOOLEAN
=
4
;
LIST
=
5
;
MAP
=
6
;
SET
=
7
;
OBJECT
=
8
;
}
message
KVData
{
map
<
string
,
DataType
>
key_type
=
1
;
map
<
string
,
ModuleAttr
>
data
=
2
;
}
message
ModuleAttr
{
// Basic type
DataType
type
=
1
;
int64
i
=
2
;
double
f
=
3
;
bool
b
=
4
;
string
s
=
5
;
KVData
map
=
6
;
KVData
list
=
7
;
KVData
set
=
8
;
KVData
object
=
9
;
//
string
name
=
10
;
string
info
=
11
;
}
// Feed Variable Description
message
FeedDesc
{
string
var_name
=
1
;
string
alias
=
2
;
};
// Fetch Variable Description
message
FetchDesc
{
string
var_name
=
1
;
string
alias
=
2
;
};
// Module Variable
message
ModuleVar
{
repeated
FetchDesc
fetch_desc
=
1
;
repeated
FeedDesc
feed_desc
=
2
;
}
// A Hub Module is stored in a directory with a file 'module_desc.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
message
ModuleDesc
{
// signature to module variable
map
<
string
,
ModuleVar
>
sign2var
=
2
;
ModuleAttr
attr
=
3
;
};
paddlehub/compat/module/module_v1.py
浏览文件 @
6acb2dd4
...
@@ -47,6 +47,10 @@ class ModuleV1(object):
...
@@ -47,6 +47,10 @@ class ModuleV1(object):
self
.
_generate_func
()
self
.
_generate_func
()
def
_load_processor
(
self
):
def
_load_processor
(
self
):
# Some module does not have a processor(e.g. ernie)
if
not
'processor_info'
in
self
.
desc
:
return
python_path
=
os
.
path
.
join
(
self
.
directory
,
'python'
)
python_path
=
os
.
path
.
join
(
self
.
directory
,
'python'
)
processor_name
=
self
.
desc
.
processor_info
processor_name
=
self
.
desc
.
processor_info
self
.
processor
=
utils
.
load_py_module
(
python_path
,
processor_name
)
self
.
processor
=
utils
.
load_py_module
(
python_path
,
processor_name
)
...
...
paddlehub/compat/paddle_utils.py
浏览文件 @
6acb2dd4
...
@@ -13,8 +13,64 @@
...
@@ -13,8 +13,64 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
copy
from
typing
import
Callable
,
List
import
paddle
import
paddle
from
paddlehub.utils.utils
import
Version
dtype_map
=
{
paddle
.
device
.
core
.
VarDesc
.
VarType
.
FP32
:
"float32"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
FP64
:
"float64"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
FP16
:
"float16"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
INT32
:
"int32"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
INT16
:
"int16"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
INT64
:
"int64"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
BOOL
:
"bool"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
INT16
:
"int16"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
UINT8
:
"uint8"
,
paddle
.
device
.
core
.
VarDesc
.
VarType
.
INT8
:
"int8"
,
}
def
convert_dtype_to_string
(
dtype
:
str
)
->
paddle
.
device
.
core
.
VarDesc
.
VarType
:
if
dtype
in
dtype_map
:
return
dtype_map
[
dtype
]
raise
TypeError
(
"dtype shoule in %s"
%
list
(
dtype_map
.
keys
()))
def
get_variable_info
(
var
:
paddle
.
Variable
)
->
dict
:
if
not
isinstance
(
var
,
paddle
.
Variable
):
raise
TypeError
(
"var shoule be an instance of paddle.Variable"
)
var_info
=
{
'name'
:
var
.
name
,
'stop_gradient'
:
var
.
stop_gradient
,
'is_data'
:
var
.
is_data
,
'error_clip'
:
var
.
error_clip
,
'type'
:
var
.
type
}
try
:
var_info
[
'dtype'
]
=
convert_dtype_to_string
(
var
.
dtype
)
var_info
[
'lod_level'
]
=
var
.
lod_level
var_info
[
'shape'
]
=
var
.
shape
except
:
pass
if
isinstance
(
var
,
paddle
.
device
.
framework
.
Parameter
):
var_info
[
'trainable'
]
=
var
.
trainable
var_info
[
'optimize_attr'
]
=
var
.
optimize_attr
var_info
[
'regularizer'
]
=
var
.
regularizer
if
Version
(
paddle
.
__version__
)
<
'1.8'
:
var_info
[
'gradient_clip_attr'
]
=
var
.
gradient_clip_attr
var_info
[
'do_model_average'
]
=
var
.
do_model_average
else
:
var_info
[
'persistable'
]
=
var
.
persistable
return
var_info
def
remove_feed_fetch_op
(
program
:
paddle
.
static
.
Program
):
def
remove_feed_fetch_op
(
program
:
paddle
.
static
.
Program
):
'''Remove feed and fetch operator and variable for fine-tuning.'''
'''Remove feed and fetch operator and variable for fine-tuning.'''
...
@@ -39,3 +95,103 @@ def remove_feed_fetch_op(program: paddle.static.Program):
...
@@ -39,3 +95,103 @@ def remove_feed_fetch_op(program: paddle.static.Program):
block
.
_remove_var
(
var
)
block
.
_remove_var
(
var
)
program
.
desc
.
flush
()
program
.
desc
.
flush
()
def
rename_var
(
block
:
paddle
.
device
.
framework
.
Block
,
old_name
:
str
,
new_name
:
str
):
'''
'''
for
op
in
block
.
ops
:
for
input_name
in
op
.
input_arg_names
:
if
input_name
==
old_name
:
op
.
_rename_input
(
old_name
,
new_name
)
for
output_name
in
op
.
output_arg_names
:
if
output_name
==
old_name
:
op
.
_rename_output
(
old_name
,
new_name
)
block
.
_rename_var
(
old_name
,
new_name
)
def
add_vars_prefix
(
program
:
paddle
.
static
.
Program
,
prefix
:
str
,
vars
:
List
[
paddle
.
Variable
]
=
None
,
excludes
:
Callable
=
None
):
'''
'''
block
=
program
.
global_block
()
vars
=
list
(
vars
)
if
vars
else
list
(
block
.
vars
.
keys
())
vars
=
[
var
for
var
in
vars
if
var
not
in
excludes
]
if
excludes
else
vars
for
var
in
vars
:
rename_var
(
block
,
var
,
prefix
+
var
)
def
remove_vars_prefix
(
program
:
paddle
.
static
.
Program
,
prefix
:
str
,
vars
:
List
[
paddle
.
Variable
]
=
None
,
excludes
:
Callable
=
None
):
'''
'''
block
=
program
.
global_block
()
vars
=
[
var
for
var
in
vars
if
var
.
startswith
(
prefix
)]
if
vars
else
[
var
for
var
in
block
.
vars
.
keys
()
if
var
.
startswith
(
prefix
)]
vars
=
[
var
for
var
in
vars
if
var
not
in
excludes
]
if
excludes
else
vars
for
var
in
vars
:
rename_var
(
block
,
var
,
var
.
replace
(
prefix
,
''
,
1
))
def
clone_program
(
origin_program
:
paddle
.
static
.
Program
,
for_test
:
bool
=
False
)
->
paddle
.
static
.
Program
:
dest_program
=
paddle
.
static
.
Program
()
_copy_vars_and_ops_in_blocks
(
origin_program
.
global_block
(),
dest_program
.
global_block
())
dest_program
=
dest_program
.
clone
(
for_test
=
for_test
)
if
not
for_test
:
for
name
,
var
in
origin_program
.
global_block
().
vars
.
items
():
dest_program
.
global_block
().
vars
[
name
].
stop_gradient
=
var
.
stop_gradient
return
dest_program
def
_copy_vars_and_ops_in_blocks
(
from_block
:
paddle
.
device
.
framework
.
Block
,
to_block
:
paddle
.
device
.
framework
.
Block
):
for
var
in
from_block
.
vars
:
var
=
from_block
.
var
(
var
)
var_info
=
copy
.
deepcopy
(
get_variable_info
(
var
))
if
isinstance
(
var
,
paddle
.
device
.
framework
.
Parameter
):
to_block
.
create_parameter
(
**
var_info
)
else
:
to_block
.
create_var
(
**
var_info
)
for
op
in
from_block
.
ops
:
all_attrs
=
op
.
all_attrs
()
if
'sub_block'
in
all_attrs
:
_sub_block
=
to_block
.
program
.
_create_block
()
_copy_vars_and_ops_in_blocks
(
all_attrs
[
'sub_block'
],
_sub_block
)
to_block
.
program
.
_rollback
()
new_attrs
=
{
'sub_block'
:
_sub_block
}
for
key
,
value
in
all_attrs
.
items
():
if
key
==
'sub_block'
:
continue
new_attrs
[
key
]
=
copy
.
deepcopy
(
value
)
else
:
new_attrs
=
copy
.
deepcopy
(
all_attrs
)
op_info
=
{
'type'
:
op
.
type
,
'inputs'
:
{
input
:
[
to_block
.
_find_var_recursive
(
var
)
for
var
in
op
.
input
(
input
)]
for
input
in
op
.
input_names
},
'outputs'
:
{
output
:
[
to_block
.
_find_var_recursive
(
var
)
for
var
in
op
.
output
(
output
)]
for
output
in
op
.
output_names
},
'attrs'
:
new_attrs
}
to_block
.
append_op
(
**
op_info
)
def
set_op_attr
(
program
:
paddle
.
static
.
Program
,
is_test
:
bool
=
False
):
for
block
in
program
.
blocks
:
for
op
in
block
.
ops
:
if
not
op
.
has_attr
(
'is_test'
):
continue
op
.
_set_attr
(
'is_test'
,
is_test
)
paddlehub/module/manager.py
浏览文件 @
6acb2dd4
...
@@ -43,14 +43,14 @@ class HubModuleNotFoundError(Exception):
...
@@ -43,14 +43,14 @@ class HubModuleNotFoundError(Exception):
class
LocalModuleManager
(
object
):
class
LocalModuleManager
(
object
):
"""
'''
LocalModuleManager is used to manage PaddleHub's local Module, which supports the installation, uninstallation,
LocalModuleManager is used to manage PaddleHub's local Module, which supports the installation, uninstallation,
and search of HubModule. LocalModuleManager is a singleton object related to the path, in other words, when the
and search of HubModule. LocalModuleManager is a singleton object related to the path, in other words, when the
LocalModuleManager object of the same home directory is generated multiple times, the same object is returned.
LocalModuleManager object of the same home directory is generated multiple times, the same object is returned.
Args:
Args:
home (str): The directory where PaddleHub modules are stored, the default is ~/.paddlehub/modules
home (str): The directory where PaddleHub modules are stored, the default is ~/.paddlehub/modules
"""
'''
_instance_map
=
{}
_instance_map
=
{}
def
__new__
(
cls
,
home
:
str
=
MODULE_HOME
):
def
__new__
(
cls
,
home
:
str
=
MODULE_HOME
):
...
...
paddlehub/module/module.py
浏览文件 @
6acb2dd4
...
@@ -17,9 +17,9 @@ import inspect
...
@@ -17,9 +17,9 @@ import inspect
import
importlib
import
importlib
import
os
import
os
import
sys
import
sys
from
typing
import
Callable
,
List
,
Optional
,
Generic
from
typing
import
Callable
,
Generic
,
List
,
Optional
from
paddlehub.utils
import
utils
from
paddlehub.utils
import
log
,
utils
from
paddlehub.compat.module.module_v1
import
ModuleV1
from
paddlehub.compat.module.module_v1
import
ModuleV1
...
@@ -58,9 +58,10 @@ def serving(func: Callable) -> Callable:
...
@@ -58,9 +58,10 @@ def serving(func: Callable) -> Callable:
class
Module
(
object
):
class
Module
(
object
):
'''
'''
'''
'''
def
__new__
(
cls
,
name
:
str
=
None
,
directory
:
str
=
None
,
version
:
str
=
None
,
**
kwargs
):
def
__new__
(
cls
,
name
:
str
=
None
,
directory
:
str
=
None
,
version
:
str
=
None
,
**
kwargs
):
if
cls
.
__name__
==
'Module'
:
if
cls
.
__name__
==
'Module'
:
# This branch come from hub.Module(name='xxx')
# This branch come from hub.Module(name='xxx')
or hub.Module(directory='xxx')
if
name
:
if
name
:
module
=
cls
.
init_with_name
(
name
=
name
,
version
=
version
,
**
kwargs
)
module
=
cls
.
init_with_name
(
name
=
name
,
version
=
version
,
**
kwargs
)
elif
directory
:
elif
directory
:
...
@@ -72,19 +73,19 @@ class Module(object):
...
@@ -72,19 +73,19 @@ class Module(object):
@
classmethod
@
classmethod
def
load
(
cls
,
directory
:
str
)
->
Generic
:
def
load
(
cls
,
directory
:
str
)
->
Generic
:
'''
'''
if
directory
.
endswith
(
os
.
sep
):
if
directory
.
endswith
(
os
.
sep
):
directory
=
directory
[:
-
1
]
directory
=
directory
[:
-
1
]
#
i
f module description file existed, try to load as ModuleV1
#
I
f module description file existed, try to load as ModuleV1
desc_file
=
os
.
path
.
join
(
directory
,
'module_desc.pb'
)
desc_file
=
os
.
path
.
join
(
directory
,
'module_desc.pb'
)
if
os
.
path
.
exists
(
desc_file
):
if
os
.
path
.
exists
(
desc_file
):
return
ModuleV1
.
load
(
desc_file
)
return
ModuleV1
.
load
(
desc_file
)
basename
=
os
.
path
.
split
(
directory
)[
-
1
]
basename
=
os
.
path
.
split
(
directory
)[
-
1
]
dirname
=
os
.
path
.
join
(
*
list
(
os
.
path
.
split
(
directory
)[:
-
1
]))
dirname
=
os
.
path
.
join
(
*
list
(
os
.
path
.
split
(
directory
)[:
-
1
]))
py_module
=
utils
.
load_py_module
(
dirname
,
'{}.module'
.
format
(
basename
))
sys
.
path
.
insert
(
0
,
dirname
)
py_module
=
importlib
.
import_module
(
'{}.module'
.
format
(
basename
))
for
_item
,
_cls
in
inspect
.
getmembers
(
py_module
,
inspect
.
isclass
):
for
_item
,
_cls
in
inspect
.
getmembers
(
py_module
,
inspect
.
isclass
):
_item
=
py_module
.
__dict__
[
_item
]
_item
=
py_module
.
__dict__
[
_item
]
...
@@ -93,13 +94,14 @@ class Module(object):
...
@@ -93,13 +94,14 @@ class Module(object):
break
break
else
:
else
:
raise
InvalidHubModule
(
directory
)
raise
InvalidHubModule
(
directory
)
sys
.
path
.
pop
(
0
)
user_module_cls
.
directory
=
directory
user_module_cls
.
directory
=
directory
return
user_module_cls
return
user_module_cls
@
classmethod
@
classmethod
def
init_with_name
(
cls
,
name
:
str
,
version
:
str
=
None
,
**
kwargs
):
def
init_with_name
(
cls
,
name
:
str
,
version
:
str
=
None
,
**
kwargs
):
'''
'''
from
paddlehub.module.manager
import
LocalModuleManager
from
paddlehub.module.manager
import
LocalModuleManager
manager
=
LocalModuleManager
()
manager
=
LocalModuleManager
()
user_module_cls
=
manager
.
search
(
name
)
user_module_cls
=
manager
.
search
(
name
)
...
@@ -107,15 +109,39 @@ class Module(object):
...
@@ -107,15 +109,39 @@ class Module(object):
user_module_cls
=
manager
.
install
(
name
,
version
)
user_module_cls
=
manager
.
install
(
name
,
version
)
directory
=
manager
.
_get_normalized_path
(
name
)
directory
=
manager
.
_get_normalized_path
(
name
)
# The HubModule in the old version will use the _initialize method to initialize,
# this function will be obsolete in a future version
if
hasattr
(
user_module_cls
,
'_initialize'
):
log
.
logger
.
warning
(
'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
)
user_module
=
user_module_cls
(
directory
=
directory
)
user_module
.
_initialize
(
**
kwargs
)
return
user_module
return
user_module_cls
(
directory
=
directory
,
**
kwargs
)
return
user_module_cls
(
directory
=
directory
,
**
kwargs
)
@
classmethod
@
classmethod
def
init_with_directory
(
cls
,
directory
:
str
,
**
kwargs
):
def
init_with_directory
(
cls
,
directory
:
str
,
**
kwargs
):
'''
'''
user_module_cls
=
cls
.
load
(
directory
)
user_module_cls
=
cls
.
load
(
directory
)
return
user_module_cls
(
**
kwargs
)
# The HubModule in the old version will use the _initialize method to initialize,
# this function will be obsolete in a future version
if
hasattr
(
user_module_cls
,
'_initialize'
):
log
.
logger
.
warning
(
'The _initialize method in HubModule will soon be deprecated, you can use the __init__() to handle the initialization of the object'
)
user_module
=
user_module_cls
(
directory
=
directory
)
user_module
.
_initialize
(
**
kwargs
)
return
user_module
return
user_module_cls
(
directory
=
directory
,
**
kwargs
)
@
classmethod
@
classmethod
def
get_py_requirements
(
cls
):
def
get_py_requirements
(
cls
):
'''
'''
req_file
=
os
.
path
.
join
(
cls
.
directory
,
'requirements.txt'
)
req_file
=
os
.
path
.
join
(
cls
.
directory
,
'requirements.txt'
)
if
not
os
.
path
.
exists
(
req_file
):
if
not
os
.
path
.
exists
(
req_file
):
return
[]
return
[]
...
@@ -125,6 +151,9 @@ class Module(object):
...
@@ -125,6 +151,9 @@ class Module(object):
class
RunModule
(
object
):
class
RunModule
(
object
):
'''
'''
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
# Avoid module being initialized multiple times
# Avoid module being initialized multiple times
if
'_is_initialize'
in
self
.
__dict__
and
self
.
_is_initialize
:
if
'_is_initialize'
in
self
.
__dict__
and
self
.
_is_initialize
:
...
@@ -149,6 +178,8 @@ class RunModule(object):
...
@@ -149,6 +178,8 @@ class RunModule(object):
@
classmethod
@
classmethod
def
get_py_requirements
(
cls
)
->
List
[
str
]:
def
get_py_requirements
(
cls
)
->
List
[
str
]:
'''
'''
py_module
=
sys
.
modules
[
cls
.
__module__
]
py_module
=
sys
.
modules
[
cls
.
__module__
]
directory
=
os
.
path
.
dirname
(
py_module
.
__file__
)
directory
=
os
.
path
.
dirname
(
py_module
.
__file__
)
req_file
=
os
.
path
.
join
(
directory
,
'requirements.txt'
)
req_file
=
os
.
path
.
join
(
directory
,
'requirements.txt'
)
...
@@ -172,6 +203,9 @@ def moduleinfo(name: str,
...
@@ -172,6 +203,9 @@ def moduleinfo(name: str,
summary
:
str
=
None
,
summary
:
str
=
None
,
type
:
str
=
None
,
type
:
str
=
None
,
meta
=
None
)
->
Callable
:
meta
=
None
)
->
Callable
:
'''
'''
def
_wrapper
(
cls
:
Generic
)
->
Generic
:
def
_wrapper
(
cls
:
Generic
)
->
Generic
:
wrap_cls
=
cls
wrap_cls
=
cls
_meta
=
RunModule
if
not
meta
else
meta
_meta
=
RunModule
if
not
meta
else
meta
...
...
paddlehub/utils/log.py
浏览文件 @
6acb2dd4
...
@@ -170,8 +170,8 @@ class FormattedText(object):
...
@@ -170,8 +170,8 @@ class FormattedText(object):
self
.
width
=
width
self
.
width
=
width
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
form
=
'
:{}{
}'
.
format
(
self
.
align
,
self
.
width
)
form
=
'
{{:{}{}}
}'
.
format
(
self
.
align
,
self
.
width
)
text
=
(
'{'
+
form
+
'}'
)
.
format
(
self
.
text
)
text
=
form
.
format
(
self
.
text
)
if
not
self
.
color
:
if
not
self
.
color
:
return
text
return
text
return
self
.
color
+
text
+
Fore
.
RESET
return
self
.
color
+
text
+
Fore
.
RESET
...
...
paddlehub/utils/parser.py
0 → 100644
浏览文件 @
6acb2dd4
# coding:utf-8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
codecs
import
sys
from
typing
import
List
import
yaml
from
paddlehub.utils.utils
import
sys_stdin_encoding
class
CSVFileParser
(
object
):
def
parse
(
self
,
csv_file
:
str
)
->
dict
:
with
codecs
.
open
(
csv_file
,
'r'
,
sys_stdin_encoding
())
as
file
:
content
=
file
.
read
()
content
=
content
.
split
(
'
\n
'
)
self
.
title
=
content
[
0
].
split
(
','
)
self
.
content
=
{}
for
key
in
self
.
title
:
self
.
content
[
key
]
=
[]
for
text
in
content
[
1
:]:
if
(
text
==
''
):
continue
for
index
,
item
in
enumerate
(
text
.
split
(
','
)):
title
=
self
.
title
[
index
]
self
.
content
[
title
].
append
(
item
)
return
self
.
content
class
YAMLFileParser
(
object
):
def
parse
(
self
,
yaml_file
:
str
)
->
dict
:
with
codecs
.
open
(
yaml_file
,
'r'
,
sys_stdin_encoding
())
as
file
:
content
=
file
.
read
()
return
yaml
.
load
(
content
,
Loader
=
yaml
.
BaseLoader
)
class
TextFileParser
(
object
):
def
parse
(
self
,
txt_file
:
str
,
use_strip
:
bool
=
True
)
->
List
:
contents
=
[]
try
:
with
codecs
.
open
(
txt_file
,
'r'
,
encoding
=
'utf8'
)
as
file
:
for
line
in
file
:
if
use_strip
:
line
=
line
.
strip
()
if
line
:
contents
.
append
(
line
)
except
:
with
codecs
.
open
(
txt_file
,
'r'
,
encoding
=
'gbk'
)
as
file
:
for
line
in
file
:
if
use_strip
:
line
=
line
.
strip
()
if
line
:
contents
.
append
(
line
)
return
contents
csv_parser
=
CSVFileParser
()
yaml_parser
=
YAMLFileParser
()
txt_parser
=
TextFileParser
()
paddlehub/utils/utils.py
浏览文件 @
6acb2dd4
...
@@ -31,10 +31,12 @@ from urllib.parse import urlparse
...
@@ -31,10 +31,12 @@ from urllib.parse import urlparse
import
packaging.version
import
packaging.version
import
paddlehub.env
as
hubenv
import
paddlehub.env
as
hubenv
import
paddlehub.utils
as
utils
class
Version
(
packaging
.
version
.
Version
):
class
Version
(
packaging
.
version
.
Version
):
'''Extended implementation of packaging.version.Version'''
'''Extended implementation of packaging.version.Version'''
def
match
(
self
,
condition
:
str
)
->
bool
:
def
match
(
self
,
condition
:
str
)
->
bool
:
'''
'''
Determine whether the given condition are met
Determine whether the given condition are met
...
@@ -76,9 +78,35 @@ class Version(packaging.version.Version):
...
@@ -76,9 +78,35 @@ class Version(packaging.version.Version):
return
_comp
(
Version
(
version
))
return
_comp
(
Version
(
version
))
def
__lt__
(
self
,
other
):
if
isinstance
(
other
,
str
):
other
=
Version
(
other
)
return
super
().
__lt__
(
other
)
def
__le__
(
self
,
other
):
if
isinstance
(
other
,
str
):
other
=
Version
(
other
)
return
super
().
__le__
(
other
)
def
__gt__
(
self
,
other
):
if
isinstance
(
other
,
str
):
other
=
Version
(
other
)
return
super
().
__gt__
(
other
)
def
__ge__
(
self
,
other
):
if
isinstance
(
other
,
str
):
other
=
Version
(
other
)
return
super
().
__ge__
(
other
)
def
__eq__
(
self
,
other
):
if
isinstance
(
other
,
str
):
other
=
Version
(
other
)
return
super
().
__eq__
(
other
)
class
Timer
(
object
):
class
Timer
(
object
):
'''Calculate runing speed and estimated time of arrival(ETA)'''
'''Calculate runing speed and estimated time of arrival(ETA)'''
def
__init__
(
self
,
total_step
:
int
):
def
__init__
(
self
,
total_step
:
int
):
self
.
total_step
=
total_step
self
.
total_step
=
total_step
self
.
last_start_step
=
0
self
.
last_start_step
=
0
...
@@ -217,3 +245,35 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
...
@@ -217,3 +245,35 @@ def load_py_module(python_path: str, py_module_name: str) -> types.ModuleType:
sys
.
path
.
pop
(
0
)
sys
.
path
.
pop
(
0
)
return
py_module
return
py_module
def
get_platform_default_encoding
()
->
str
:
'''
'''
if
utils
.
platform
.
is_windows
():
return
'gbk'
return
'utf8'
def
sys_stdin_encoding
()
->
str
:
'''
'''
encoding
=
sys
.
stdin
.
encoding
if
encoding
is
None
:
encoding
=
sys
.
getdefaultencoding
()
if
encoding
is
None
:
encoding
=
get_platform_default_encoding
()
return
encoding
def
sys_stdout_encoding
()
->
str
:
'''
'''
encoding
=
sys
.
stdout
.
encoding
if
encoding
is
None
:
encoding
=
sys
.
getdefaultencoding
()
if
encoding
is
None
:
encoding
=
get_platform_default_encoding
()
return
encoding
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录