Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
2b69eaea
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2b69eaea
编写于
6月 07, 2020
作者:
J
Jason
提交者:
GitHub
6月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #141 from PaddlePaddle/develop_hrnet
fixed input shape for hrnet
上级
9fe089f3
e3f56c10
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
18 addition
and
6 deletion
+18
-6
paddlex/cv/models/hrnet.py
paddlex/cv/models/hrnet.py
+5
-3
paddlex/cv/nets/segmentation/hrnet.py
paddlex/cv/nets/segmentation/hrnet.py
+13
-3
未找到文件。
paddlex/cv/models/hrnet.py
浏览文件 @
2b69eaea
...
...
@@ -77,6 +77,7 @@ class HRNet(DeepLabv3p):
self
.
class_weight
=
class_weight
self
.
ignore_index
=
ignore_index
self
.
labels
=
None
self
.
fixed_input_shape
=
None
def
build_net
(
self
,
mode
=
'train'
):
model
=
paddlex
.
cv
.
nets
.
segmentation
.
HRNet
(
...
...
@@ -86,7 +87,8 @@ class HRNet(DeepLabv3p):
use_bce_loss
=
self
.
use_bce_loss
,
use_dice_loss
=
self
.
use_dice_loss
,
class_weight
=
self
.
class_weight
,
ignore_index
=
self
.
ignore_index
)
ignore_index
=
self
.
ignore_index
,
fixed_input_shape
=
self
.
fixed_input_shape
)
inputs
=
model
.
generate_inputs
()
model_out
=
model
.
build_net
(
inputs
)
outputs
=
OrderedDict
()
...
...
@@ -170,6 +172,6 @@ class HRNet(DeepLabv3p):
return
super
(
HRNet
,
self
).
train
(
num_epochs
,
train_dataset
,
train_batch_size
,
eval_dataset
,
save_interval_epochs
,
log_interval_steps
,
save_dir
,
pretrain_weights
,
optimizer
,
learning_rate
,
lr_decay_power
,
use_vdl
,
sensitivities_file
,
eval_metric_loss
,
early_stop
,
pretrain_weights
,
optimizer
,
learning_rate
,
lr_decay_power
,
use_vdl
,
sensitivities_file
,
eval_metric_loss
,
early_stop
,
early_stop_patience
,
resume_checkpoint
)
paddlex/cv/nets/segmentation/hrnet.py
浏览文件 @
2b69eaea
...
...
@@ -38,7 +38,8 @@ class HRNet(object):
use_bce_loss
=
False
,
use_dice_loss
=
False
,
class_weight
=
None
,
ignore_index
=
255
):
ignore_index
=
255
,
fixed_input_shape
=
None
):
# dice_loss或bce_loss只适用两类分割中
if
num_classes
>
2
and
(
use_bce_loss
or
use_dice_loss
):
raise
ValueError
(
...
...
@@ -66,6 +67,7 @@ class HRNet(object):
self
.
use_dice_loss
=
use_dice_loss
self
.
class_weight
=
class_weight
self
.
ignore_index
=
ignore_index
self
.
fixed_input_shape
=
fixed_input_shape
self
.
backbone
=
paddlex
.
cv
.
nets
.
hrnet
.
HRNet
(
width
=
width
,
feature_maps
=
"stage4"
)
...
...
@@ -131,8 +133,16 @@ class HRNet(object):
def
generate_inputs
(
self
):
inputs
=
OrderedDict
()
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
[
None
,
3
,
None
,
None
],
name
=
'image'
)
if
self
.
fixed_input_shape
is
not
None
:
input_shape
=
[
None
,
3
,
self
.
fixed_input_shape
[
1
],
self
.
fixed_input_shape
[
0
]
]
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
input_shape
,
name
=
'image'
)
else
:
inputs
[
'image'
]
=
fluid
.
data
(
dtype
=
'float32'
,
shape
=
[
None
,
3
,
None
,
None
],
name
=
'image'
)
if
self
.
mode
==
'train'
:
inputs
[
'label'
]
=
fluid
.
data
(
dtype
=
'int32'
,
shape
=
[
None
,
1
,
None
,
None
],
name
=
'label'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录