Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
613375e5
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
接近 2 年 前同步成功
通知
284
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
613375e5
编写于
9月 18, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix paddle2.0 adaptation problem
上级
73e68b4b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
28 addition
and
10 deletion
+28
-10
paddlehub/finetune/trainer.py
paddlehub/finetune/trainer.py
+8
-3
paddlehub/module/module.py
paddlehub/module/module.py
+5
-3
paddlehub/utils/log.py
paddlehub/utils/log.py
+4
-4
paddlehub/utils/platform.py
paddlehub/utils/platform.py
+11
-0
未找到文件。
paddlehub/finetune/trainer.py
浏览文件 @
613375e5
...
...
@@ -17,7 +17,7 @@ import os
import
pickle
import
time
from
collections
import
defaultdict
from
typing
import
Any
,
Callable
,
List
from
typing
import
Any
,
Callable
,
Generic
,
List
import
paddle
from
paddle.distributed
import
ParallelEnv
...
...
@@ -185,7 +185,7 @@ class Trainer(object):
timer
.
count
()
if
(
batch_idx
+
1
)
%
log_interval
==
0
and
self
.
local_rank
==
0
:
lr
=
self
.
optimizer
.
current_step
_lr
()
lr
=
self
.
optimizer
.
get
_lr
()
avg_loss
/=
log_interval
if
self
.
use_vdl
:
self
.
log_writer
.
add_scalar
(
tag
=
'TRAIN/loss'
,
step
=
timer
.
current_step
,
value
=
avg_loss
)
...
...
@@ -346,7 +346,12 @@ class Trainer(object):
optimizer(paddle.optimizer.Optimizer) : Optimizer used.
loss(paddle.Tensor) : Loss tensor.
'''
self
.
optimizer
.
minimize
(
loss
)
self
.
optimizer
.
step
()
self
.
learning_rate_step
(
epoch_idx
,
batch_idx
,
self
.
optimizer
.
get_lr
(),
loss
)
def
learning_rate_step
(
self
,
epoch_idx
:
int
,
batch_idx
:
int
,
learning_rate
:
Generic
,
loss
:
paddle
.
Tensor
):
if
isinstance
(
learning_rate
,
paddle
.
optimizer
.
_LRScheduler
):
learning_rate
.
step
()
def
optimizer_zero_grad
(
self
,
epoch_idx
:
int
,
batch_idx
:
int
,
optimizer
:
paddle
.
optimizer
.
Optimizer
):
'''
...
...
paddlehub/module/module.py
浏览文件 @
613375e5
...
...
@@ -135,7 +135,7 @@ class Module(object):
manager
=
LocalModuleManager
()
user_module_cls
=
manager
.
search
(
name
)
if
not
user_module_cls
or
not
user_module_cls
.
version
.
match
(
version
):
user_module_cls
=
manager
.
install
(
name
,
version
)
user_module_cls
=
manager
.
install
(
name
=
name
,
version
=
version
)
directory
=
manager
.
_get_normalized_path
(
user_module_cls
.
name
)
...
...
@@ -148,7 +148,8 @@ class Module(object):
user_module
=
user_module_cls
(
directory
=
directory
)
user_module
.
_initialize
(
**
kwargs
)
return
user_module
return
user_module_cls
(
directory
=
directory
,
**
kwargs
)
user_module_cls
.
directory
=
directory
return
user_module_cls
(
**
kwargs
)
@
classmethod
def
init_with_directory
(
cls
,
directory
:
str
,
**
kwargs
):
...
...
@@ -165,7 +166,8 @@ class Module(object):
user_module
=
user_module_cls
(
directory
=
directory
)
user_module
.
_initialize
(
**
kwargs
)
return
user_module
return
user_module_cls
(
directory
=
directory
,
**
kwargs
)
user_module_cls
.
directory
=
directory
return
user_module_cls
(
**
kwargs
)
@
classmethod
def
get_py_requirements
(
cls
):
...
...
paddlehub/utils/log.py
浏览文件 @
613375e5
...
...
@@ -73,7 +73,7 @@ class Logger(object):
self
.
__dict__
[
key
.
lower
()]
=
functools
.
partial
(
self
.
__call__
,
conf
[
'level'
])
self
.
format
=
colorlog
.
ColoredFormatter
(
'%(log_color)s[%(asctime)-15s] [%(levelname)8s] - %(message)s'
,
'%(log_color)s[%(asctime)-15s] [%(levelname)8s]
%(reset)s
- %(message)s'
,
log_colors
=
{
key
:
conf
[
'color'
]
for
key
,
conf
in
log_config
.
items
()})
...
...
@@ -178,13 +178,13 @@ class FormattedText(object):
======== ====================================
color(str) : Text color, default is None(depends on terminal configuration)
'''
_MAP
=
{
'red'
:
Fore
.
RED
,
'yellow'
:
Fore
.
YELLOW
,
'green'
:
Fore
.
GREEN
,
'blue'
:
Fore
.
BLUE
}
_MAP
=
{
'red'
:
Fore
.
RED
,
'yellow'
:
Fore
.
YELLOW
,
'green'
:
Fore
.
GREEN
,
'blue'
:
Fore
.
BLUE
,
'cyan'
:
Fore
.
CYAN
}
def
__init__
(
self
,
text
:
str
,
width
:
int
,
align
:
str
=
'<'
,
color
:
str
=
None
):
def
__init__
(
self
,
text
:
str
,
width
:
int
=
None
,
align
:
str
=
'<'
,
color
:
str
=
None
):
self
.
text
=
text
self
.
align
=
align
self
.
color
=
FormattedText
.
_MAP
[
color
]
if
color
else
color
self
.
width
=
width
self
.
width
=
width
if
width
else
len
(
self
.
text
)
def
__repr__
(
self
)
->
str
:
form
=
'{{:{}{}}}'
.
format
(
self
.
align
,
self
.
width
)
...
...
paddlehub/utils/platform.py
浏览文件 @
613375e5
...
...
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
platform
...
...
@@ -22,3 +23,13 @@ def get_platform() -> str:
def
is_windows
()
->
str
:
return
get_platform
().
lower
().
startswith
(
"windows"
)
def
get_platform_info
()
->
dict
:
return
{
'python_version'
:
'.'
.
join
(
map
(
str
,
sys
.
version_info
[
0
:
3
])),
'platform_version'
:
platform
.
version
(),
'platform_system'
:
platform
.
system
(),
'platform_architecture'
:
platform
.
architecture
(),
'platform_type'
:
platform
.
platform
()
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录