Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
f71593d2
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f71593d2
编写于
5月 10, 2020
作者:
S
Steffy-zxf
提交者:
GitHub
5月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug that was resulted by add object detection task (#577)
上级
ae9edc1c
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
9 addition
and
32 deletion
+9
-32
paddlehub/finetune/task/base_task.py
paddlehub/finetune/task/base_task.py
+9
-32
未找到文件。
paddlehub/finetune/task/base_task.py
浏览文件 @
f71593d2
...
...
@@ -344,10 +344,6 @@ class BaseTask(object):
# set default phase
self
.
enter_phase
(
"train"
)
@
property
def
base_main_program
(
self
):
return
self
.
_base_main_program
@
contextlib
.
contextmanager
def
phase_guard
(
self
,
phase
):
self
.
enter_phase
(
phase
)
...
...
@@ -397,7 +393,7 @@ class BaseTask(object):
self
.
_build_env_start_event
()
self
.
env
.
is_inititalized
=
True
self
.
env
.
main_program
=
clone_program
(
self
.
base_main_program
,
for_test
=
False
)
self
.
_
base_main_program
,
for_test
=
False
)
self
.
env
.
startup_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
self
.
env
.
main_program
,
...
...
@@ -410,7 +406,6 @@ class BaseTask(object):
self
.
env
.
metrics
=
self
.
_add_metrics
()
if
self
.
is_predict_phase
or
self
.
is_test_phase
:
# Todo: paddle.fluid.core_avx.EnforceNotMet: Getting 'tensor_desc' is not supported by the type of var kCUDNNFwdAlgoCache. at
self
.
env
.
main_program
=
clone_program
(
self
.
env
.
main_program
,
for_test
=
True
)
hub
.
common
.
paddle_helper
.
set_op_attr
(
...
...
@@ -1063,10 +1058,8 @@ class BaseTask(object):
capacity
=
64
,
use_double_buffer
=
True
,
iterable
=
True
)
data_reader
=
data_loader
.
set_sample_list_generator
(
self
.
reader
,
self
.
places
)
# data_reader = data_loader.set_batch_generator(
# self.reader, places=self.places)
data_reader
=
data_loader
.
set_batch_generator
(
self
.
reader
,
places
=
self
.
places
)
else
:
data_feeder
=
fluid
.
DataFeeder
(
feed_list
=
self
.
feed_list
,
place
=
self
.
place
)
...
...
@@ -1083,28 +1076,12 @@ class BaseTask(object):
step_run_state
.
run_step
=
1
num_batch_examples
=
len
(
batch
)
if
self
.
return_numpy
==
2
:
fetch_result
=
self
.
exe
.
run
(
self
.
main_program_to_be_run
,
feed
=
batch
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
False
)
# fetch_result = [x if isinstance(x,fluid.LoDTensor) else np.array(x) for x in fetch_result]
fetch_result
=
[
x
if
hasattr
(
x
,
'recursive_sequence_lengths'
)
else
np
.
array
(
x
)
for
x
in
fetch_result
]
elif
self
.
return_numpy
:
fetch_result
=
self
.
exe
.
run
(
self
.
main_program_to_be_run
,
feed
=
batch
,
fetch_list
=
self
.
fetch_list
)
else
:
fetch_result
=
self
.
exe
.
run
(
self
.
main_program_to_be_run
,
feed
=
batch
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
False
)
fetch_result
=
self
.
exe
.
run
(
self
.
main_program_to_be_run
,
feed
=
batch
,
fetch_list
=
self
.
fetch_list
,
return_numpy
=
self
.
return_numpy
)
if
not
self
.
return_numpy
:
fetch_result
=
[
np
.
array
(
x
)
for
x
in
fetch_result
]
for
index
,
result
in
enumerate
(
fetch_result
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录