Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
c3bae66b
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
289
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c3bae66b
编写于
6月 18, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update some
上级
6b4a7f02
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
4 addition
and
4 deletion
+4
-4
dygraph/datasets/__init__.py
dygraph/datasets/__init__.py
+1
-0
dygraph/train.py
dygraph/train.py
+2
-2
dygraph/utils/__init__.py
dygraph/utils/__init__.py
+0
-1
dygraph/utils/utils.py
dygraph/utils/utils.py
+1
-1
未找到文件。
dygraph/datasets/__init__.py
浏览文件 @
c3bae66b
...
@@ -13,3 +13,4 @@
...
@@ -13,3 +13,4 @@
# limitations under the License.
# limitations under the License.
from
.optic_disc_seg
import
OpticDiscSeg
from
.optic_disc_seg
import
OpticDiscSeg
from
.cityscapes
import
Cityscapes
dygraph/train.py
浏览文件 @
c3bae66b
...
@@ -78,7 +78,7 @@ def parse_args():
...
@@ -78,7 +78,7 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
'--pretrained_model'
,
'--pretrained_model'
,
dest
=
'pretrained_model'
,
dest
=
'pretrained_model'
,
help
=
'The path of pretr
ia
ned weight'
,
help
=
'The path of pretr
ai
ned weight'
,
type
=
str
,
type
=
str
,
default
=
None
)
default
=
None
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -161,7 +161,7 @@ def train(model,
...
@@ -161,7 +161,7 @@ def train(model,
optimizer
.
minimize
(
loss
)
optimizer
.
minimize
(
loss
)
model
.
clear_gradients
()
model
.
clear_gradients
()
logging
.
info
(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={}"
.
format
(
logging
.
info
(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={}"
.
format
(
epoch
+
1
,
num_epochs
,
step
+
1
,
num_steps_each_epoch
,
epoch
+
1
,
num_epochs
,
step
+
1
,
len
(
batch_sampler
)
,
loss
.
numpy
()))
loss
.
numpy
()))
if
((
epoch
+
1
)
%
save_interval_epochs
==
0
if
((
epoch
+
1
)
%
save_interval_epochs
==
0
...
...
dygraph/utils/__init__.py
浏览文件 @
c3bae66b
...
@@ -16,4 +16,3 @@ from . import logging
...
@@ -16,4 +16,3 @@ from . import logging
from
.
import
download
from
.
import
download
from
.metrics
import
ConfusionMatrix
from
.metrics
import
ConfusionMatrix
from
.utils
import
*
from
.utils
import
*
from
.distributed
import
DistributedBatchSampler
dygraph/utils/utils.py
浏览文件 @
c3bae66b
...
@@ -48,8 +48,8 @@ def get_environ_info():
...
@@ -48,8 +48,8 @@ def get_environ_info():
def
load_pretrained_model
(
model
,
pretrained_model
):
def
load_pretrained_model
(
model
,
pretrained_model
):
logging
.
info
(
'Load pretrained model!'
)
if
pretrained_model
is
not
None
:
if
pretrained_model
is
not
None
:
logging
.
info
(
'Load pretrained model!'
)
if
os
.
path
.
exists
(
pretrained_model
):
if
os
.
path
.
exists
(
pretrained_model
):
ckpt_path
=
os
.
path
.
join
(
pretrained_model
,
'model'
)
ckpt_path
=
os
.
path
.
join
(
pretrained_model
,
'model'
)
para_state_dict
,
_
=
fluid
.
load_dygraph
(
ckpt_path
)
para_state_dict
,
_
=
fluid
.
load_dygraph
(
ckpt_path
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录